In [1]:
import abc
import logging
from collections import namedtuple

import numpy as np
import open3d as o3d
import plotly.graph_objects as go
import transforms3d as t3d
from probreg import callbacks
# NOTE(review): `mu` is used by the CoherentPointDrift cell below; it is assumed
# to be probreg's math_utils module (provides squared_kernel_sum) -- confirm
# against the installed probreg version.
from probreg import math_utils as mu

import utils


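Import the data and load it as the source and target point clouds. In this case, we load the data from text files and downsample it using a voxel size of 1.0 units (the value passed to `utils.prepare_source_and_target_nonrigid_3d` in the next cell).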

In [4]:
# Load the two face scans. The third argument (1.0) is forwarded unchanged to the
# project helper -- presumably the downsampling voxel size; confirm in utils.
source, target = utils.prepare_source_and_target_nonrigid_3d(
    'face-x.txt', 'face-y.txt', 1.0)
source_np = np.asarray(source.points)
target_np = np.asarray(target.points)

# One marker-only trace per cloud: source drawn in green, target in red.
traces = [
    go.Scatter3d(x=cloud[:, 0], y=cloud[:, 1], z=cloud[:, 2],
                 mode='markers',
                 marker=dict(size=1, color=colour, opacity=1.0))
    for cloud, colour in ((source_np, 'green'), (target_np, 'red'))
]
fig = go.Figure(data=traces)
fig.show()

PointCloud with 22958 points.
PointCloud with 22990 points.


In [None]:
# Result containers exchanged between the E-step and the M-step.
#   EstepResult: pt1 / p1 are the column / row sums of the posterior matrix,
#                px is (posterior matrix) . target, n_p is the sum of p1.
#   MstepResult: the updated transformation, the updated sigma2 and the value
#                `q` that `registration` uses as its convergence criterion.
EstepResult = namedtuple('EstepResult', ['pt1', 'p1', 'px', 'n_p'])
MstepResult = namedtuple('MstepResult', ['transformation', 'sigma2', 'q'])

# Logger used by `registration` for its per-iteration debug message.
log = logging.getLogger(__name__)


class CoherentPointDrift(metaclass=abc.ABCMeta):
    """Coherent Point Drift algorithm.
    This is an abstract class.
    Based on this class, it is inherited by rigid, affine, nonrigid classes
    according to the type of transformation.
    In this class, Estimation step in EM algorithm is implemented and
    Maximization step is implemented in the inherited classes.

    Subclasses must implement ``_initialize`` and ``_maximization_step`` and
    set ``self._tf_type`` to something other than ``None``; ``registration``
    refuses to run otherwise.

    Args:
        source (numpy.ndarray, optional): Source point cloud data.
        use_cuda (bool, optional): Use CUDA.
    """
    def __init__(self, source=None, use_cuda=False):
        self._source = source
        self._tf_type = None
        self._callbacks = []
        if use_cuda:
            # Imported lazily so that cupy is only required when CUDA is requested.
            import cupy as cp
            # A relative import (`from . import ...`) cannot work in a notebook
            # cell, so the package is named explicitly.
            # NOTE(review): assumes the installed probreg ships `cupy_utils` -- confirm.
            from probreg import cupy_utils
            self.xp = cp
            self.cupy_utils = cupy_utils
            self._squared_kernel_sum = cupy_utils.squared_kernel_sum
        else:
            self.xp = np
            self._squared_kernel_sum = mu.squared_kernel_sum

    def set_source(self, source):
        """Replace the source point cloud used by subsequent registrations."""
        self._source = source

    def set_callbacks(self, callbacks):
        """Append ``callbacks`` to the list invoked after every EM iteration.

        Each callback is called with the current transformation as its only
        argument (see ``registration``).
        """
        self._callbacks.extend(callbacks)

    @abc.abstractmethod
    def _initialize(self, target):
        """Return the initial ``MstepResult`` for ``target`` (subclass hook)."""
        return MstepResult(None, None, None)

    def expectation_step(self, t_source, target, sigma2, w=0.0):
        """Expectation step for CPD

        Args:
            t_source: Transformed source points, a 2-D array (one point per row).
            target: Target points, a 2-D array (one point per row).
            sigma2: Variance used in the Gaussian kernel.
            w (float, optional): Weight of the uniform term added to the
                denominator; must satisfy ``0 <= w < 1``.

        Returns:
            EstepResult: ``pt1`` and ``p1`` are the column and row sums of the
            normalised kernel matrix, ``px`` is that matrix multiplied by
            ``target`` and ``n_p`` is the sum of ``p1``.

        Raises:
            ValueError: If ``w`` is outside ``[0, 1)``; ``w == 1`` would divide
                by zero in ``w / (1 - w)``.
        """
        assert t_source.ndim == 2 and target.ndim == 2, "source and target must have 2 dimensions."
        if not 0.0 <= w < 1.0:
            raise ValueError("w must satisfy 0 <= w < 1, got {}".format(w))

        # Row i holds the squared distances from source point i to every target point.
        pmat = self.xp.stack([self.xp.sum(self.xp.square(target - ts), axis=1) for ts in t_source])
        pmat = self.xp.exp(-pmat / (2.0 * sigma2))

        # Constant contributed by the uniform term; it is zero when w == 0.
        c = (2.0 * np.pi * sigma2) ** (t_source.shape[1] * 0.5)
        c *= w / (1.0 - w) * t_source.shape[0] / target.shape[0]
        den = self.xp.sum(pmat, axis=0)
        # Guard against division by zero for target points no source point reaches.
        den[den == 0] = self.xp.finfo(np.float32).eps
        den += c

        pmat = self.xp.divide(pmat, den)
        pt1 = self.xp.sum(pmat, axis=0)
        p1 = self.xp.sum(pmat, axis=1)
        px = self.xp.dot(pmat, target)
        # Reduce with the active backend (self.xp), like every other call in this method.
        return EstepResult(pt1, p1, px, self.xp.sum(p1))

    def maximization_step(self, target, estep_res, sigma2_p=None):
        """Run the subclass M-step on the stored source with the active backend."""
        return self._maximization_step(self._source, target, estep_res, sigma2_p, xp=self.xp)

    @staticmethod
    @abc.abstractmethod
    def _maximization_step(source, target, estep_res, sigma2_p=None, xp=np):
        """Compute the next ``MstepResult`` from an ``EstepResult`` (subclass hook)."""
        return None

    def registration(self, target, w=0.0,
                     maxiter=50, tol=0.001):
        """Alternate E- and M-steps until the criterion ``q`` stops changing.

        Args:
            target: Target points, passed to ``_initialize`` and to both steps.
            w (float, optional): Forwarded to ``expectation_step``.
            maxiter (int, optional): Maximum number of EM iterations.
            tol (float, optional): Stop once ``|q - previous q| < tol``.

        Returns:
            MstepResult: The result of the last M-step, or the initial result
            from ``_initialize`` when ``maxiter`` is 0.

        Raises:
            AssertionError: If the subclass did not set a transformation type.
            ValueError: If no source point cloud has been set.
        """
        assert self._tf_type is not None, "transformation type is None."
        if self._source is None:
            raise ValueError("source is None; pass it to the constructor or call set_source().")
        res = self._initialize(target)
        q = res.q
        for i in range(maxiter):
            t_source = res.transformation.transform(self._source)
            estep_res = self.expectation_step(t_source, target, res.sigma2, w)
            res = self.maximization_step(target, estep_res, res.sigma2)
            for callback in self._callbacks:
                callback(res.transformation)
            log.debug("Iteration: {}, Criteria: {}".format(i, res.q))
            if abs(res.q - q) < tol:
                break
            q = res.q
        return res