In [None]:
import os
import scipy
import numpy as np
import trimesh
from trimesh import viewer
from trimesh.transformations import translation_matrix, rotation_matrix, concatenate_matrices
from scipy.spatial.transform import Rotation
from scipy.spatial import cKDTree

# Notebook-wide NumPy display settings: three decimals, no scientific notation,
# and no line wrapping, so the printed transforms stay on one row per matrix row.
np.set_printoptions(
    linewidth=np.inf,
    suppress=True,
    precision=3,
)


In [None]:
# Load the two wolf scans used throughout the notebook.
mesh1, mesh2 = (trimesh.exchange.load.load(ply) for ply in ('wolf1.ply', 'wolf2.ply'))

# Optional experiment: uncomment the lines below to move mesh1 by a known
# rigid transform (translation + rotation about the y axis) before registering.
# T = translation_matrix([-11, 0, 11])
# xaxis, yaxis, zaxis = [1, 0, 0], [0, 1, 0], [0, 0, 1]
# Rx = rotation_matrix(0, xaxis)
# Ry = rotation_matrix(-np.pi/6, yaxis)
# Rz = rotation_matrix(0, zaxis)
# M = concatenate_matrices(T, Rx, Ry, Rz)
# mesh1.apply_transform(M)

# Export both meshes, merged into a single scene, to scene.html.
# A previous export is deleted first so the file is always created fresh.
if os.path.exists("scene.html"):
    os.remove("scene.html")
combined = trimesh.util.concatenate([mesh1, mesh2])
html = viewer.notebook.scene_to_html(combined.scene())
with open("scene.html", "w") as scene_file:
    scene_file.write(html)


In [None]:
def get_laplacian(faces, num_vertices=None):
    """Build the combinatorial graph Laplacian ``L = D - A`` of a triangle mesh.

    Two vertices are adjacent when they share a triangle edge. Every edge has
    weight 1, ``A`` is the symmetric adjacency matrix and ``D`` the diagonal
    matrix of vertex degrees.

    Parameters
    ----------
    faces : array_like of int, shape (3, F)
        Vertex indices of the triangles, one triangle per *column* (the
        transpose of the usual ``(F, 3)`` face array).
    num_vertices : int, optional
        Number of rows/columns of the result. When omitted it is inferred as
        ``faces.max() + 1``, which is too small if the highest-numbered
        vertices are not referenced by any face; pass the real vertex count
        in that case.

    Returns
    -------
    scipy.sparse.coo_matrix, shape (num_vertices, num_vertices)
        The Laplacian, obtained by converting the ``D - A`` result to COO.

    Raises
    ------
    ValueError
        If ``faces`` does not have exactly three rows, if it is empty and
        ``num_vertices`` is not given, or if a face index does not fit into
        ``num_vertices``.
    """
    faces = np.asarray(faces)
    if faces.ndim != 2 or faces.shape[0] != 3:
        raise ValueError(
            "faces must have shape (3, F) (one triangle per column), got %r" % (faces.shape,))
    if num_vertices is None:
        if faces.size == 0:
            raise ValueError(
                "cannot infer the number of vertices from an empty face array; pass num_vertices")
        num_vertices = int(faces.max()) + 1
    elif faces.size and int(faces.max()) >= num_vertices:
        raise ValueError(
            "face index %d is out of range for num_vertices=%d" % (int(faces.max()), num_vertices))

    # face to edge: the three edges (v0, v1), (v1, v2), (v0, v2) of every triangle
    edges = np.concatenate([faces[:2], faces[1:], faces[::2]], axis=1)
    # to undirected: add the reversed copy of each edge, then drop duplicates
    # (np.unique also sorts the (row, col) pairs lexicographically)
    both_ways = np.concatenate([edges, edges[::-1]], axis=1)
    if both_ways.shape[1]:
        row, col = np.unique(both_ways, axis=1)
    else:
        row = col = np.empty(0, dtype=np.intp)
    # edge weight
    data = np.ones(len(row))
    # adjacency matrix
    A = scipy.sparse.coo_matrix((data, (row, col)), shape=(num_vertices, num_vertices))
    # ensure symmetry
    assert (abs(A - A.T) > 1e-10).nnz == 0
    # degree matrix
    D = scipy.sparse.diags(np.array(A.sum(axis=1)).flatten())
    # combinatorial Laplacian
    L = D - A
    return L.tocoo()

def _cg_solve(A, b, tol=1e-5, maxiter=1000):
    """Run ``scipy.sparse.linalg.cg`` with the tolerance keyword this SciPy accepts.

    Returns the ``(solution, info)`` pair produced by ``cg``.
    """
    try:
        # SciPy 1.12 renamed the relative tolerance to ``rtol``; the old
        # ``tol`` keyword was removed in SciPy 1.14.
        return scipy.sparse.linalg.cg(A, b, rtol=tol, maxiter=maxiter)
    except TypeError:
        # Older SciPy releases only know the legacy ``tol`` keyword.
        return scipy.sparse.linalg.cg(A, b, tol=tol, maxiter=maxiter)


def register(source, target, w1=0.1, w2=1.0, w3=1.0, max_iter=100):
    """Iteratively deform the vertices of ``source`` towards ``target``.

    Every iteration assembles and solves one sparse linear system ``A u = b``.
    The unknown vector ``u`` is laid out as ``[r (3), t (3), r_i (3N), z_i (3N)]``:
    a global 3-vector ``r`` (turned into a rotation afterwards), a global
    translation ``t``, one 3-vector per vertex ``r_i`` (presumably per-vertex
    rotations; they are solved for but not used afterwards) and the new vertex
    positions ``z_i``. The positions found in one iteration are the
    linearisation point ``x_t`` of the next one.

    Parameters
    ----------
    source : mesh-like
        Object exposing ``vertices`` (N x 3) and ``faces`` (F x 3); in this
        notebook a trimesh mesh. Its connectivity defines the graph Laplacian.
    target : mesh-like
        Object exposing ``vertices``; a KD-tree over them supplies the nearest
        target point of every current vertex.
    w1 : float
        Weight of the nearest-target-point term on the right-hand side.
    w2 : float
        Weight of the current positions on the right-hand side and of the
        ``r``/``t`` columns in the ``z`` rows of the system.
    w3 : float
        Weight of the Laplacian terms in the ``z`` rows of the system.
    max_iter : int
        Number of iterations. There is no convergence test, so all of them run.

    Returns
    -------
    transform : tuple (R, t)
        ``R`` is the 3x3 product of the per-iteration rotations and ``t`` the
        sum of the per-iteration translations.
    z : array, shape (N, 3)
        Vertex positions from the last solve (a copy of ``source.vertices``
        when ``max_iter`` is 0).

    Raises
    ------
    ValueError
        If the Laplacian built from the faces does not have one row per
        source vertex.
    AssertionError
        If a per-iteration rotation matrix does not have determinant 1.

    Notes
    -----
    Prints the solver status and the accumulated transform on every iteration.
    """
    x = source.vertices
    y = target.vertices
    N = len(x)
    L = get_laplacian(source.faces.T)
    if L.shape[0] != N:
        # get_laplacian sizes the matrix from the largest face index, so a
        # mismatch means unreferenced trailing vertices or out-of-range faces.
        raise ValueError(
            "Laplacian has %d rows but source has %d vertices; every vertex up to the "
            "last one must be referenced by a face" % (L.shape[0], N))
    # CSR view of L; its indices/indptr give the block layout of ``off_diag``.
    L_csr = L.tocsr()
    z = x.copy()
    transform = (np.eye(3), np.zeros(3))
    tree = cKDTree(y)

    # ---- blocks of A and b that do not depend on the current positions ----
    empty_3_by_3N = scipy.sparse.csr_matrix((3, 3*N))  # all zeros
    empty_3N_by_3 = scipy.sparse.csr_matrix((3*N, 3))  # all zeros
    A_rir = A_rit = empty_3_by_3N
    A_rrj = A_trj = empty_3N_by_3
    A_tt = N * np.eye(3)
    A_zit = -np.tile(np.eye(3), (1, N))
    A_tzj = -np.tile(np.eye(3), (N, 1)) * w2
    A_zizj = scipy.sparse.eye(3 * N) * (w1 + w2) + 2 * scipy.sparse.kron(L, scipy.sparse.eye(3), format='csr') * w3
    b_r = np.zeros(3)
    b_ri = np.zeros(3 * N)

    for iteration in range(max_iter):
        x_t = z
        # build a matrix equivalent to cross product: X_t[n] is the
        # skew-symmetric matrix of vertex n
        zeros = np.zeros(N)
        X_t = np.stack([
            np.stack((zeros, -x_t[:, 2], x_t[:, 1]), axis=1),
            np.stack((x_t[:, 2], zeros, -x_t[:, 0]), axis=1),
            np.stack((-x_t[:, 1], x_t[:, 0], zeros), axis=1)
        ], axis=1)
        # helper expressions: L applied to the per-vertex 3x3 matrices
        LX_t = L.dot(X_t.reshape(N, 9)).reshape(N, 3, 3)
        LXX_t = L.dot(np.einsum('abc,ace->abe', X_t, X_t).reshape(N, 9)).reshape(N, 3, 3)
        # One 3x3 block X_i - X_j per stored entry (i, j) of L. The data is
        # ordered like L.row/L.col while the layout comes from L_csr, so this
        # relies on get_laplacian returning its entries in CSR order.
        off_diag = scipy.sparse.bsr_matrix((X_t[L.row] - X_t[L.col],
                                            L_csr.indices,
                                            L_csr.indptr),
                                           shape=(3*N, 3*N)).tocsr()
        diag = scipy.sparse.block_diag(LX_t).tocsr()
        # derivatives w.r.t. r
        A_rr = -np.einsum('abc,acd->bd', X_t, X_t)
        A_tr = np.sum(X_t, axis=0)
        # the N matrices X_t[n] side by side, shape (3, 3N)
        A_zir = -X_t.transpose(1, 0, 2).reshape(3, 3*N)
        # derivatives w.r.t. t
        A_rt = -np.sum(X_t, axis=0)
        # derivatives w.r.t. r_j
        A_rirj = scipy.sparse.block_diag(
            LXX_t
            - np.einsum('acd,ade->ace', LX_t, X_t)
            - np.einsum('ade,acd->ace', LX_t, X_t)).tocsr()
        A_zirj = off_diag - diag
        # derivatives w.r.t. z_j
        # the N matrices X_t[n] stacked vertically, shape (3N, 3)
        A_rzj = X_t.reshape(3*N, 3) * w2
        A_rizj = (off_diag + diag) * w3
        # assemble coefficient matrix A (every block converted to CSR first)
        blocks = [[A_rr, A_tr, A_rir, A_zir],
                  [A_rt, A_tt, A_rit, A_zit],
                  [A_rrj, A_trj, A_rirj, A_zirj],
                  [A_rzj, A_tzj, A_rizj, A_zizj]]
        A = scipy.sparse.bmat(
            [[scipy.sparse.csr_matrix(block) for block in block_row] for block_row in blocks],
            format='csr')
        # build blocks of result vector b
        b_t = -np.sum(x_t, axis=0)
        # index of the closest target vertex for every current vertex
        _, nearest = tree.query(z, k=1)
        b_zi = (w1 * y[nearest] + w2 * x_t + w3 * 2 * L.dot(x_t)).flatten()
        # assemble result vector b
        b = np.concatenate([b_r, b_t, b_ri, b_zi])
        # solve system of linear equations Ax = b
        # NOTE(review): CG assumes a symmetric positive-definite matrix. By
        # construction the weighted blocks mirror their transposes only when
        # w2 == w3 == 1, and definiteness is never checked, so watch the
        # printed ``info`` (0 means converged) and consider a general solver
        # such as GMRES if it does not converge.
        solution, info = _cg_solve(A, b, tol=1e-5, maxiter=1000)
        print('info: ', info)
        # update resulting transformation
        r = solution[:3]
        t = solution[3:6]
        z = solution[6+3*N:].reshape(-1, 3)
        # NOTE(review): r is read as xyz Euler angles here; if it is meant as
        # a small rotation vector, Rotation.from_rotvec may be the better
        # match - confirm against the derivation.
        R = Rotation.from_euler('xyz', r).as_matrix()
        # ensure that R is a rotation matrix
        assert np.isclose(np.linalg.det(R), 1.0)
        # NOTE(review): translations are simply summed; composing rigid
        # motions would normally give R @ t_old + t - confirm which is intended.
        transform = (np.dot(R, transform[0]), t + transform[1])
        print(iteration, '\n', transform, '\n')
    return transform, z


In [None]:
# Register mesh1 onto mesh2: (R, t) is the accumulated global transform and
# z holds the final per-vertex positions of mesh1.
(R, t), z = register(source=mesh1, target=mesh2, max_iter=100)

In [None]:
# Rebuild a mesh from the registered vertex positions z, reusing mesh1's faces.
trans_mesh = trimesh.base.Trimesh(vertices=z, faces=mesh1.faces)

# Colour the result randomly so it can be told apart from mesh2 in the viewer.
trans_mesh.visual.vertex_colors = trimesh.visual.random_color()

# Overwrite scene.html with the registered mesh shown together with the target.
if os.path.exists("scene.html"):
    os.remove("scene.html")
overlay = trimesh.util.concatenate([trans_mesh, mesh2])
html = viewer.notebook.scene_to_html(overlay.scene())
with open("scene.html", "w") as scene_file:
    scene_file.write(html)
