## Description taken from Omar Todd's Report

Remeshing means creating a new mesh that is a different discretization of the same underlying continuous shape. Remeshing is a non-trivial operation, and also an umbrella term for a variety of slightly different tasks.<br>

For instance, some meshes may be of poor quality: some faces might be highly irregular, e.g. far-from-equilateral triangles with a very obtuse angle, which can significantly degrade the accuracy of numerical computations. In that case, one may want to create a new mesh at a different or identical resolution, but with higher quality.<br>
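One simple way to quantify such irregular faces is to compute each triangle's interior angles and flag triangles with a very large (obtuse) or very small angle. The sketch below uses plain NumPy; the helper names and the 120°/20° thresholds are illustrative assumptions, not values taken from the remeshing code used later.

```python
import numpy as np


def triangle_angles(verts: np.ndarray, faces: np.ndarray) -> np.ndarray:
    """Interior angles (in degrees) of every triangle, shape (n_faces, 3)."""
    corners = verts[faces]  # (n_faces, 3, 3): coordinates of each face's three corners
    angles = np.empty(faces.shape, dtype=float)
    for k in range(3):
        a = corners[:, k]
        u = corners[:, (k + 1) % 3] - a  # the two edge vectors leaving corner k
        v = corners[:, (k + 2) % 3] - a
        cos = np.einsum('ij,ij->i', u, v) / (
            np.linalg.norm(u, axis=1) * np.linalg.norm(v, axis=1)
        )
        angles[:, k] = np.degrees(np.arccos(np.clip(cos, -1.0, 1.0)))
    return angles


def flag_poor_triangles(verts, faces, max_angle=120.0, min_angle=20.0):
    """Mask of faces with a very obtuse or very acute angle (hypothetical thresholds)."""
    angles = triangle_angles(verts, faces)
    return (angles.max(axis=1) > max_angle) | (angles.min(axis=1) < min_angle)


# An equilateral triangle and a flat "sliver" sharing the edge (0, 1)
verts = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.5, np.sqrt(3) / 2, 0.0], [0.5, 0.05, 0.0]])
faces = np.array([[0, 1, 2], [0, 1, 3]])
print(flag_poor_triangles(verts, faces))  # [False  True]
```

The minimum angle is a common, easy-to-read quality measure; other choices (e.g. edge-length ratios) would work similarly.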

Conversely, one may be happy with the current mesh quality but want to create a finer mesh by subdividing the existing one, i.e. a mesh where all current vertices are preserved and new ones are added (e.g. at face barycenters or at edge midpoints). The advantage is that data is easier to transport from the original mesh to the new one.<br>
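A minimal NumPy sketch of the edge-midpoint variant (1-to-4 split; barycenter insertion is not shown) makes the data-transport advantage concrete: original vertices keep their indices, so their data is copied unchanged, and each new midpoint simply averages the two endpoint values. The function name is hypothetical.

```python
import numpy as np


def midpoint_subdivide(verts, faces, values=None):
    """1-to-4 midpoint subdivision of a triangle mesh.

    Original vertices keep indices 0..n-1; one new vertex is appended per unique
    edge. Per-vertex `values` (shape (n,) or (n, c)) are copied for old vertices
    and averaged along the edge for new ones.
    """
    n = len(verts)
    # Edges grouped by position in the face: all (v0,v1), then all (v1,v2), then all (v2,v0)
    edges = np.concatenate([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]])
    unique_edges, inverse = np.unique(np.sort(edges, axis=1), axis=0, return_inverse=True)
    mid = n + inverse.reshape(3, -1).T  # (n_faces, 3): midpoint index of edges 01, 12, 20

    new_verts = np.vstack([verts, 0.5 * (verts[unique_edges[:, 0]] + verts[unique_edges[:, 1]])])

    v0, v1, v2 = faces.T
    m01, m12, m20 = mid.T
    # Three corner triangles plus the central one, all keeping the original orientation
    new_faces = np.concatenate([
        np.stack([v0, m01, m20], axis=1),
        np.stack([v1, m12, m01], axis=1),
        np.stack([v2, m20, m12], axis=1),
        np.stack([m01, m12, m20], axis=1),
    ])

    if values is None:
        return new_verts, new_faces
    new_values = np.concatenate(
        [values, 0.5 * (values[unique_edges[:, 0]] + values[unique_edges[:, 1]])]
    )
    return new_verts, new_faces, new_values


# Usage: two triangles forming a unit square, with some per-vertex data
verts = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 1.0, 0.0]])
faces = np.array([[0, 1, 2], [0, 2, 3]])
temperature = np.array([0.0, 1.0, 2.0, 3.0])
v2, f2, t2 = midpoint_subdivide(verts, faces, temperature)
print(v2.shape, f2.shape)  # (9, 3) (8, 3)
```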

Similar questions arise for mesh decimation. Moreover, it is generally non-trivial to obtain an exact number of nodes in the new mesh; usually only an approximate target can be specified.<br>
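Vertex-clustering decimation, a simple technique swapped in here purely for illustration (it is not necessarily what the remesher below does), shows why only a rough target is usually possible: vertices are merged per grid cell, so the output size is controlled only indirectly, through the cell size. The function name and the test grid are hypothetical.

```python
import numpy as np


def grid_cluster_decimate(verts, faces, cell_size):
    """Merge all vertices that fall in the same cubic grid cell (vertex clustering)."""
    keys = np.floor((verts - verts.min(axis=0)) / cell_size).astype(np.int64)
    _, cluster, counts = np.unique(keys, axis=0, return_inverse=True, return_counts=True)
    cluster = cluster.reshape(-1)  # keep it 1-D across NumPy versions
    # Each cluster is represented by the mean of its member vertices
    new_verts = np.zeros((len(counts), verts.shape[1]))
    np.add.at(new_verts, cluster, verts)
    new_verts /= counts[:, None]
    # Re-index faces and drop those that collapsed to an edge or a point
    new_faces = cluster[faces]
    keep = (
        (new_faces[:, 0] != new_faces[:, 1])
        & (new_faces[:, 1] != new_faces[:, 2])
        & (new_faces[:, 2] != new_faces[:, 0])
    )
    return new_verts, new_faces[keep]


# Usage: a flat 11 x 11 grid of vertices, two triangles per square
n = 11
xs, ys = np.meshgrid(np.linspace(0, 1, n), np.linspace(0, 1, n), indexing='ij')
verts = np.column_stack([xs.ravel(), ys.ravel(), np.zeros(n * n)])
idx = np.arange(n * n).reshape(n, n)
a, b = idx[:-1, :-1].ravel(), idx[1:, :-1].ravel()
c, d = idx[1:, 1:].ravel(), idx[:-1, 1:].ravel()
faces = np.concatenate([np.stack([a, b, c], axis=1), np.stack([a, c, d], axis=1)])

for cell in (0.15, 0.35):
    v, f = grid_cluster_decimate(verts, faces, cell)
    print(cell, len(v), len(f))
```

The resulting vertex count depends on how the surface happens to fall across the grid, which is why one can only aim for a number of nodes rather than hit it exactly.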

The remeshing example below is meant to increase the mesh resolution (number of vertices), make it similar across all meshes in the dataset, and fix mesh quality where needed. On the downside, it is relatively expensive and makes remapping pre-existing data from one mesh to another non-trivial. It is therefore best used once and for all as a preprocessing step, rather than as part of a fine-to-coarse ML/DL architecture.<br>
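When per-vertex data does have to follow a remeshed surface, one common approximation is to interpolate from the nearest original vertices. The sketch below uses inverse-distance weighting over the k nearest neighbours with `scipy.spatial.cKDTree`; this is an assumed, simple choice (it ignores connectivity and is only approximate), and the function name is hypothetical.

```python
import numpy as np
from scipy.spatial import cKDTree


def transfer_vertex_data(src_verts, src_values, dst_verts, k=3, eps=1e-12):
    """Map per-vertex data from a source mesh onto the vertices of a new mesh.

    Inverse-distance weighting over the k nearest source vertices; only vertex
    positions are used, mesh connectivity is ignored.
    """
    dist, idx = cKDTree(src_verts).query(dst_verts, k=k)
    if k == 1:  # query returns 1-D arrays when k == 1
        dist, idx = dist[:, None], idx[:, None]
    weights = 1.0 / (dist + eps)
    weights /= weights.sum(axis=1, keepdims=True)
    neighbour_values = src_values[idx]  # (n_dst, k) or (n_dst, k, n_channels)
    if neighbour_values.ndim == 3:
        weights = weights[..., None]
    return (weights * neighbour_values).sum(axis=1)


# Usage with random stand-in vertex positions
rng = np.random.default_rng(0)
src = rng.normal(size=(500, 3))
values = src[:, 0]  # e.g. a scalar field defined on the original vertices
dst = rng.normal(size=(200, 3))
mapped = transfer_vertex_data(src, values, dst)
print(mapped.shape)  # (200,)
```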

[Edit: The algorithm has changed somewhat, and I still have to review its behaviour. On the bright side, it now seems able to output the specified number of nodes, and it is much faster. On the downside, I had to adjust my wrapper to get similar quality.]

## My Preprocessing

In [1]:
# Automatically reload imported modules when their source changes (IPython extension)
%load_ext autoreload
%autoreload 2

In [2]:
ROOT_PATH = '../../../../'

import sys

sys.path.append(ROOT_PATH)

In [3]:
from importlib import reload

In [4]:
import numpy as np
import os

import torch
from pytorch3d.io import load_obj
from pytorch3d.structures import Meshes
from pytorch3d.vis.plotly_vis import plot_scene

import meshtools

In [5]:
from meshtools import polydata as pd
from meshtools.polydata import io as vtkIO
from meshtools.polydata import remeshing

from biobank import io as bb

In [6]:
# root_dir = r'C:\Users\Loic\Documents\Data\BB\12579\brain\shapes'
root_dir = '/vol/biomedic3/bglocker/brainshapes/'  # 1000596/'
# remesh_dir = r'C:\Users\Loic\Documents\Projects\meng-omar\DATA\BB'

In [7]:
# Create the polydata dataset and metadata
# structures = bb.generate_structures()
dataset_filenames = vtkIO.generate_dataset_filenames(root_dir, {'br_stem': 'BrStem'})

In [8]:
_dataset_filenames = dataset_filenames[:30]

In [10]:
subject_ids = bb.generate_subject_ids(root_dir, {'br_stem': 'BrStem'})

In [11]:
subject_dataset = bb.read_subject_polydatas(_dataset_filenames, root_dir)

In [12]:
dataset = bb.generate_data(subject_dataset)
data_ids = bb.generate_data_ids(dataset_filenames, subject_ids)

In [13]:
len(dataset), len(subject_dataset), len(data_ids), len(subject_ids)

(30, 30, 14502, 14502)

In [14]:
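# Prepare for remeshing: set up a mesh whose vertex density will be used as reference
# Note: despite its name, `triangles` is passed as `nclus`; the output shows it sets the vertex count (2000)
triangles = 2000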
target_mesh = subject_dataset[0]['br_stem']
remesher = remeshing.Remesher()
# remesher.set_num_points_per_unit_area_to_target(target_mesh)
target_mesh = remesher.remesh(target_mesh, nclus=triangles, nsubdivide=5)
print(target_mesh.points.shape, target_mesh.faces.shape)

(2000, 3) (15984,)
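The `faces` array printed above is flat: 15984 = 4 × 3996 entries for the 3996 face normals shown in the next cell. This is consistent with a VTK-style cell layout `[3, i, j, k, 3, i, j, k, ...]`, where each triangle is prefixed by its vertex count, which is what the `.reshape(-1, 4)[:, 1:]` calls in the next cell rely on. A hypothetical helper doing the same conversion, but failing loudly on non-triangular cells, could look like this:

```python
import numpy as np


def flat_faces_to_triangles(flat_faces) -> np.ndarray:
    """Convert a VTK-style flat cell array [3, i, j, k, 3, ...] to an (n_faces, 3) array."""
    flat = np.asarray(flat_faces)
    if flat.size % 4 != 0:
        raise ValueError("array length is not a multiple of 4: mixed or non-triangular cells")
    cells = flat.reshape(-1, 4)
    if not np.all(cells[:, 0] == 3):
        raise ValueError("non-triangular cells present")
    return cells[:, 1:]  # drop the per-cell vertex count


print(flat_faces_to_triangles([3, 0, 1, 2, 3, 2, 1, 3]))
# [[0 1 2]
#  [2 1 3]]
```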


In [15]:
target_mesh.face_normals.shape, target_mesh.center_of_mass(), target_mesh.center

((3996, 3),
 array([ 92.81333829, 109.1403265 ,  79.74146076]),
 [92.22482498758211, 108.5850091437623, 76.48119261265295])

In [16]:
source_mesh = subject_dataset[20]['br_stem']
source_verts = np.copy(source_mesh.points)
source_faces = np.copy(source_mesh.faces).reshape(-1, 4)[:, 1:]
print(f'Source Mesh: faces {source_faces.shape} verts: {source_verts.shape}')

target_verts = np.copy(target_mesh.points)
target_faces = np.copy(target_mesh.faces).reshape(-1, 4)[:, 1:]

remeshed_source = remesher.remesh(source_mesh, nclus=triangles)  # source_mesh.points.shape[0])
remeshed_source_verts = np.copy(remeshed_source.points)
remeshed_source_faces = np.copy(remeshed_source.faces).reshape(-1, 4)[:, 1:]
# print(remeshed_source_faces.shape, remeshed_source_verts.shape)
print(f'Remeshed: faces {remeshed_source_faces.shape} verts: {remeshed_source_verts.shape}')

mesh = Meshes(
    verts=[torch.tensor(target_verts), torch.tensor(source_verts), torch.tensor(remeshed_source_verts),],
    faces=[torch.tensor(target_faces), torch.tensor(source_faces), torch.tensor(remeshed_source_faces),],
)

fig = plot_scene({
    "Brain Stem": {
        f"target, Volume: {target_mesh.volume}": mesh[0],
        f"source, Volume: {source_mesh.volume}": mesh[1],
        f"remesh, Volume: {remeshed_source.volume}": mesh[2],
    },
})
fig.show()

Source Mesh: faces (1280, 3) verts: (642, 3)
Remeshed: faces (3996, 3) verts: (2000, 3)
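The counts printed above (642 vertices / 1280 faces for the source, 2000 / 3996 after remeshing) are exactly what Euler's formula predicts for a closed, genus-0 triangle mesh: with V - E + F = 2 and 3F = 2E, one gets F = 2V - 4, i.e. 2 × 642 - 4 = 1280 and 2 × 2000 - 4 = 3996. A check of this kind can flag remeshed outputs with holes or handles; the sketch below assumes the brain-stem surface is meant to be closed and sphere-like, and the function names are hypothetical.

```python
import numpy as np


def euler_characteristic(faces: np.ndarray, n_verts: int) -> int:
    """V - E + F for a triangle mesh given as an (n_faces, 3) index array."""
    edges = np.concatenate([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]])
    n_edges = len(np.unique(np.sort(edges, axis=1), axis=0))
    return n_verts - n_edges + len(faces)


def expected_face_count(n_verts: int, genus: int = 0) -> int:
    """Triangles in a closed orientable triangle mesh: F = 2V - 4 + 4g."""
    return 2 * n_verts - 4 + 4 * genus


print(expected_face_count(642), expected_face_count(2000))  # 1280 3996

# Tetrahedron: the smallest closed genus-0 triangle mesh
tet_faces = np.array([[0, 1, 2], [0, 3, 1], [1, 3, 2], [2, 3, 0]])
print(euler_characteristic(tet_faces, 4))  # 2
```

On the notebook's meshes, `euler_characteristic(remeshed_source_faces, len(remeshed_source_verts))` should return 2 if the remeshed surface is closed with genus 0.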


In [17]:
from sklearn.neighbors import NearestNeighbors

def chamfer_distance(x, y, metric='l2', direction='bi'):
    r"""
    https://gist.github.com/sergeyprokudin/c4bf4059230da8db8256e36524993367

    Chamfer distance between two point clouds.

    Parameters
    ----------
    x: numpy array [n_points_x, n_dims]
        first point cloud
    y: numpy array [n_points_y, n_dims]
        second point cloud
    metric: string, default 'l2'
        metric to use for distance computation. Must be supported by scikit-learn's
        KD-tree (e.g. 'l2', 'manhattan', 'chebyshev').
    direction: str
        direction of Chamfer distance.
            'y_to_x':  computes average minimal distance from every point in y to x
            'x_to_y':  computes average minimal distance from every point in x to y
            'bi': compute both and sum them
    Returns
    -------
    chamfer_dist: float
        mean nearest-neighbour distance (not squared):
            'x_to_y': mean_{x_i in x} min_{y_j in y} ||x_i - y_j||
            'y_to_x': mean_{y_j in y} min_{x_i in x} ||x_i - y_j||
            'bi':     sum of the two terms above
    """

    def _mean_nn_dist(query, reference):
        # Average distance from each point in `query` to its nearest neighbour in `reference`
        nn = NearestNeighbors(n_neighbors=1, algorithm='kd_tree', metric=metric).fit(reference)
        return np.mean(nn.kneighbors(query)[0])

    if direction == 'y_to_x':
        chamfer_dist = _mean_nn_dist(y, x)
    elif direction == 'x_to_y':
        chamfer_dist = _mean_nn_dist(x, y)
    elif direction == 'bi':
        chamfer_dist = _mean_nn_dist(y, x) + _mean_nn_dist(x, y)
    else:
        raise ValueError("Invalid direction type. Supported types: 'y_to_x', 'x_to_y', 'bi'")

    return chamfer_dist
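Two Chamfer conventions are common: a sum of squared nearest-neighbour distances (the formula given in the referenced gist) and a mean of unsquared distances (what `chamfer_distance` computes). They live on different scales, so it helps to be explicit about which one is reported. A minimal sketch computing either variant with `scipy.spatial.cKDTree` (Euclidean only; function name and arguments are hypothetical):

```python
import numpy as np
from scipy.spatial import cKDTree


def chamfer(x: np.ndarray, y: np.ndarray, squared: bool = False, reduction: str = "mean") -> float:
    """Bidirectional Chamfer distance between point clouds x (n, d) and y (m, d)."""
    d_xy, _ = cKDTree(y).query(x)  # for each x_i, distance to its nearest y_j
    d_yx, _ = cKDTree(x).query(y)  # for each y_j, distance to its nearest x_i
    if squared:
        d_xy, d_yx = d_xy ** 2, d_yx ** 2
    if reduction == "mean":
        return float(d_xy.mean() + d_yx.mean())
    if reduction == "sum":
        return float(d_xy.sum() + d_yx.sum())
    raise ValueError("reduction must be 'mean' or 'sum'")


x = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
y = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [3.0, 0.0, 0.0]])
print(chamfer(x, y))                                 # mean, unsquared: 2/3
print(chamfer(x, y, squared=True, reduction="sum"))  # sum, squared: 4.0
```

The default (`squared=False`, `reduction="mean"`) corresponds to `chamfer_distance(x, y)` with `metric='l2'` and `direction='bi'`.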

In [18]:
import math  # used by RigidRegistration.manual_rotations
from typing import Tuple, List
from scipy.spatial import KDTree


class RigidRegistration:
    
    def __init__(self, fixed_image: np.ndarray):
        self.fixed_image = fixed_image
        self.fixed_image_kd_tree = KDTree(fixed_image)
        self.fixed_mean_centered, self.fixed_mean = self.mean_centering(fixed_image)
        self.fixed_vertices = fixed_image.shape[0]
        self.fixed_dim = fixed_image.shape[1]
        
    def get_fixed_mean_centering(self) -> Tuple[np.ndarray, np.ndarray]:
        return self.fixed_mean_centered, self.fixed_mean
        
    def mean_centering(self, image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        # TODO: Accept a density weighting for each pixel in the image
        mean = np.mean(image, axis=0)
        mean_centering = image - mean
        return mean_centering, mean
    
    def calc_rotation_matrix(self, moving_image_mean_centered: np.ndarray, moving_mean: np.ndarray = None) -> Tuple[np.ndarray, np.ndarray]:
        """
        https://johnwlambert.github.io/icp/

        2-D or 3-D registration with known correspondences: row i of the moving
        image is assumed to correspond to row i of the fixed image.
        Registration occurs in the zero-centered coordinate system, and the result
        must then be transported back.

        Kabsch Algorithm: https://en.wikipedia.org/wiki/Kabsch_algorithm

        Args:
            moving_image_mean_centered: mean-centered array of shape (N, D) -- point cloud to align (source)
            moving_mean: mean of the moving image before centering, shape (D,); required to compute t

        Returns:
            r: optimal rotation (D, D), applied to row vectors as `points @ r`
            t: optimal translation (D, )
        """
        assert moving_image_mean_centered.shape == (self.fixed_vertices, self.fixed_dim)
        
        cross_cov = moving_image_mean_centered.T @ self.fixed_mean_centered
        u, _, v_t = np.linalg.svd(cross_cov)
        
        # Check for reflection case
        s = np.eye(self.fixed_dim)
        det = np.linalg.det(u) * np.linalg.det(v_t.T)
        if not np.isclose(det, 1.):
            s[self.fixed_dim - 1, self.fixed_dim - 1] = -1
        
        r = u @ s @ v_t
        t = self.fixed_mean - moving_mean @ r
        
        return r, t
    
    def manual_rotations(self, moving_image: np.ndarray) -> np.ndarray:
        for theta in range(5, 360, 5):
            radians = np.radians(theta)
            rotate = np.array([
                [math.cos(radians), 0, -math.sin(radians)],
                [0, 1, 0],
                [math.sin(radians), 0, math.cos(radians)],
            ])
            moving_image_rotated = moving_image @ rotate
            knn_dist, l2_dist = self.calc_error(moving_image_rotated)
            print(theta, knn_dist, l2_dist)

        return moving_image_rotated
    
    def apply(self, moving_image_mean_centered: np.ndarray, rotate: np.ndarray, translate: np.ndarray) -> np.ndarray:
        return moving_image_mean_centered @ rotate + translate
    
    def calc_error(self, moving_image: np.ndarray, knn_bi_dir: bool = False):
        # TODO: Could replace knn_dist with chamfer dist
        knn_dist = self.fixed_image_kd_tree.query(moving_image)[0].mean()
        if knn_bi_dir:
            # Make KD tree and find nn in opposite direction
            # Calculate mean knn_dist
            pass
        l2_dist = np.linalg.norm(moving_image - self.fixed_image)
        return knn_dist, l2_dist
    
    def align(self, moving_image: np.ndarray, moving_image_faces: np.ndarray = None, n_iter: int = 1, eps: float = 1e-2) -> np.ndarray:
        knn_dist, l2_dist = self.calc_error(moving_image)
        print(knn_dist, l2_dist)
        if knn_dist < eps:
            return moving_image
        
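        # With fixed index correspondences, iterations after the first typically
        # return (numerically) the identity transform, so n_iter > 1 has little effect
        for _ in range(n_iter):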
            moving_image_mean_centered, moving_mean = self.mean_centering(moving_image)
            r, t = self.calc_rotation_matrix(moving_image_mean_centered, moving_mean)
            moving_image = self.apply(moving_image, r, t)
            knn_dist, l2_dist = self.calc_error(moving_image)
            print(knn_dist, l2_dist)
            if knn_dist < eps:
                return moving_image
        
        # moving_image = self.manual_rotations(moving_image)
        
        return moving_image
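`calc_rotation_matrix` solves the Kabsch problem with known correspondences: row i of the moving cloud is paired with row i of the fixed cloud. Two independently remeshed surfaces have no such shared vertex ordering, even with the same number of points, which may explain part of the residual error printed below. The classic remedy is iterative closest point (ICP): re-pair every moving point with its nearest fixed point, solve Kabsch, apply, and repeat. A minimal point-to-point sketch, independent of the `RigidRegistration` class (function names are hypothetical; rotations use the column-vector convention and are applied to row arrays as `points @ r.T + t`):

```python
import numpy as np
from scipy.spatial import cKDTree


def kabsch(p: np.ndarray, q: np.ndarray):
    """Rotation r and translation t minimising sum ||r @ p_i + t - q_i||^2 (rows paired)."""
    p_mean, q_mean = p.mean(axis=0), q.mean(axis=0)
    h = (p - p_mean).T @ (q - q_mean)  # cross-covariance (d, d)
    u, _, vt = np.linalg.svd(h)
    s = np.eye(p.shape[1])
    if np.linalg.det(vt.T @ u.T) < 0:  # avoid returning a reflection
        s[-1, -1] = -1.0
    r = vt.T @ s @ u.T
    return r, q_mean - r @ p_mean


def icp(source: np.ndarray, target: np.ndarray, n_iter: int = 50, tol: float = 1e-10):
    """Point-to-point ICP; returns aligned points, total (r, t) and final mean NN distance."""
    tree = cKDTree(target)
    current = source.copy()
    r_total, t_total = np.eye(source.shape[1]), np.zeros(source.shape[1])
    prev_err = np.inf
    for _ in range(n_iter):
        dist, idx = tree.query(current)  # re-estimate correspondences
        err = dist.mean()
        if prev_err - err < tol:  # stop once the error no longer decreases
            break
        prev_err = err
        r, t = kabsch(current, target[idx])
        current = current @ r.T + t
        r_total, t_total = r @ r_total, r @ t_total + t  # compose with previous steps
    final_err = tree.query(current)[0].mean()
    return current, r_total, t_total, final_err


# Usage: target is a centred 7 x 5 x 4 grid; source is the same grid slightly rotated and shifted
g = np.stack(np.meshgrid(np.arange(7), np.arange(5), np.arange(4), indexing='ij'), axis=-1)
target = g.reshape(-1, 3).astype(float)
target -= target.mean(axis=0)
theta = np.radians(3.0)
rz = np.array([[np.cos(theta), -np.sin(theta), 0.0],
               [np.sin(theta), np.cos(theta), 0.0],
               [0.0, 0.0, 1.0]])
source = target @ rz.T + np.array([0.05, -0.03, 0.02])
aligned, r_est, t_est, err = icp(source, target)
print(err)  # close to 0
```

ICP only converges to the nearest local optimum, so a rough initial alignment (e.g. the centroid matching already done in `align`) still matters.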

In [19]:
reg = RigidRegistration(target_verts)  # , target_mesh.faces)
optimal_source_verts = reg.align(remeshed_source_verts, n_iter=1)

6.4681542554427995 1292.4369733405872
1.2521449280598547 1096.889454748822


In [20]:
mean_centered_fixed, _ = reg.get_fixed_mean_centering()
mean_centered_moving, _ = reg.mean_centering(remeshed_source_verts)

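best_mesh = Meshes(
    # Use one float dtype for all vertices; faces must stay integer (int64) indices,
    # which torch.Tensor(...) would silently convert to float32
    verts=[
        torch.tensor(target_verts, dtype=torch.float32),
        torch.tensor(remeshed_source_verts, dtype=torch.float32),
        torch.tensor(mean_centered_fixed, dtype=torch.float32),
        torch.tensor(mean_centered_moving, dtype=torch.float32),
        torch.tensor(optimal_source_verts, dtype=torch.float32),
    ],
    faces=[
        torch.tensor(target_faces, dtype=torch.int64),
        torch.tensor(remeshed_source_faces, dtype=torch.int64),
        torch.tensor(target_faces, dtype=torch.int64),
        torch.tensor(remeshed_source_faces, dtype=torch.int64),
        torch.tensor(remeshed_source_faces, dtype=torch.int64),
    ],
)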

fig = plot_scene({
    f"All Orientations": {
        f"target": best_mesh[0],
        f"remeshed": best_mesh[1],
    },
    "Both mean centered": {
        "target": best_mesh[2],
        "remeshed": best_mesh[3],
    },
    "Centered + Rotated": {
        "target": best_mesh[0],
        "aligned": best_mesh[4],
    },
}, ncols=3)
fig.show()

Todo: The cortex will not align well when there are large vertices?

In [24]:
from typing import Tuple, List
import matplotlib.tri as mtri
import matplotlib.pyplot as plt


def plot_wireframe_and_meshes(
    vertices: np.ndarray,
    pred_verts: np.ndarray,
    triangles: np.ndarray,
    figsize: Tuple[int, int] = (20, 15),
    elevations: Tuple[int, ...] = (0,),
    azimuths: int = 5,
    alpha: float = 0.8,
    wireframe_alpha: float = 0.3,
    pred_triangles: np.ndarray = None,
):
    # `pred_verts` needs its own connectivity unless it shares the vertex ordering of `vertices`
    if pred_triangles is None:
        pred_triangles = triangles

    triang = mtri.Triangulation(vertices[:, 0], vertices[:, 1], triangles=triangles)
    z = vertices[:, 2]

    triang_pred = mtri.Triangulation(pred_verts[:, 0], pred_verts[:, 1], triangles=pred_triangles)
    pred_z = pred_verts[:, 2]

    nrows = len(elevations)
    ncols = azimuths
    fig, ax = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=figsize,
        subplot_kw=dict(projection="3d"),
        squeeze=False,  # always a 2-D array of axes, even for a single row or column
    )

    azimuth_interval = 360 / ncols

    for j, elevation in enumerate(elevations):
        for i in range(ncols):
            azimuth = azimuth_interval * i
            ax[j][i].set_title(f'E: {int(elevation)}, A: {int(azimuth)}')
            ax[j][i].view_init(elevation, azimuth)
            ax[j][i].plot_trisurf(triang, z, edgecolor='grey', alpha=alpha)
            # alpha=0.0 would hide the edges as well, so the overlay uses wireframe_alpha
            ax[j][i].plot_trisurf(triang_pred, pred_z, edgecolor='lightpink', alpha=wireframe_alpha)
            ax[j][i].set_xlabel('x')
            ax[j][i].set_ylabel('y')
            ax[j][i].set_zlabel('z')

    plt.show()

In [26]:
target_verts.shape, optimal_source_verts.shape

((2000, 3), (2000, 3))

In [29]:
# `.verts` is empty here, so it does not hold the vertex coordinates; those are in `.points`
subject_dataset[0]['br_stem'].verts

array([], dtype=int64)

In [27]:
# Compare the remeshed target with the original subject-0 mesh, each with its own faces
original_mesh = subject_dataset[0]['br_stem']
original_verts = np.asarray(original_mesh.points)
original_faces = np.asarray(original_mesh.faces).reshape(-1, 4)[:, 1:]
plot_wireframe_and_meshes(target_verts, original_verts, target_faces, azimuths=4, pred_triangles=original_faces)

In [None]:
# The aligned source comes from a different remeshing, so it needs its own faces
plot_wireframe_and_meshes(target_verts, optimal_source_verts, target_faces, azimuths=4, pred_triangles=remeshed_source_faces)