# Surface Simplification Using Quadric Error Metrics

Implementation of the surface simplification algorithm described in the paper [Surface Simplification Using Quadric Error Metrics](https://www.cs.cmu.edu/~./garland/Papers/quadrics.pdf) by Michael Garland and Paul S. Heckbert.

### Algorithm Summary
1. Compute the quadric error matrix, $Q$, for each vertex.
2. Select all valid pairs.
3. Compute the optimal target vertex $\bar{v}$ for each pair. The cost of contracting the pair is $\bar{v}^T (Q_1 + Q_2) \bar{v}$.
4. Find the pair with the lowest cost and collapse it.  This can be done with a heap to speed up the process.

### Running
Either run this locally or view the colab notebook at https://colab.research.google.com/drive/1T1KKc6WKizQxXjPQqpfnOJuFGW46JXe2?usp=sharing.

### Local Setup
Enter the project root directory and run

`python3 -m venv venv`

`source venv/bin/activate`

`pip3 install -r requirements.txt`

If using `jupyter notebook`, register the kernel with:

`python -m ipykernel install --user --name=mesh_simplify`

and select it as the active kernel.

Otherwise, use VS Code and select the `venv` environment as the Python interpreter.

In [323]:
from collections import namedtuple
import numpy as np
import plotly.graph_objects as go


# Mesh used for the step-by-step walkthrough below; uncomment one of the
# other paths to run the walkthrough on a different model.
# MESH_FNAME = "../assets/bunny_1k.obj"
MESH_FNAME = "../assets/cube.obj"
# MESH_FNAME = "../assets/Model1.obj"

### Load Starting Mesh

In [324]:
ModVertex = namedtuple("ModVertex", ["coord", "parent_idx"])


class ModMesh:
    """
    Mesh that supports incremental vertex contraction.

    Contractions are recorded in a parent-pointer forest: every original
    vertex index stays a key of ``self.vertices``.  A vertex whose
    ``parent_idx`` is -1 is a root (a live vertex of the simplified mesh);
    otherwise it points at the vertex it was merged into.  Faces keep their
    original indices and are remapped to roots on demand by ``_reduce``.
    """

    def __init__(self, vertices: np.ndarray, faces: np.ndarray):
        """
        Class for storing and simplifying a mesh.
        :param vertices: (n, 3) array of vertex coordinates
        :param faces: (m, 3) array of vertex indices for each face
        """
        # vertex index -> ModVertex(coord, parent_idx); parent_idx == -1 marks
        # a root vertex.
        self.vertices = {}
        for i in range(len(vertices)):
            self.vertices[i] = ModVertex(vertices[i], -1)

        # Number of live (root) vertices; decremented on every real merge.
        self.n_vertices = len(vertices)

        self.faces = faces.copy()

    def __len__(self):
        """Return the number of live (root) vertices."""
        return self.n_vertices

    def get_vertex(self, index):
        """
        Get coordinate of the root vertex that ``index`` has been merged into.

        :param index: index of vertex
        :return: (3, ) coordinate of root vertex, index of root vertex
        """
        # Walk the parent pointers iteratively: a recursive walk exceeds
        # Python's recursion limit once a merge chain grows past ~1000 links.
        root = index
        while self.vertices[root].parent_idx != -1:
            root = self.vertices[root].parent_idx

        # Path compression: re-point every vertex on the walked path directly
        # at the root so repeated lookups stay short.  Only parent_idx of
        # non-root entries changes, so the roots and coordinates reported are
        # unaffected.
        node = index
        while node != root:
            entry = self.vertices[node]
            self.vertices[node] = entry._replace(parent_idx=root)
            node = entry.parent_idx

        return self.vertices[root].coord, root

    def combine_vertices(self, v1_idx, v2_idx, v_bar):
        """
        Combine two vertices into one new vertex.

        The root of ``v1_idx`` survives and is moved to ``v_bar``; the root of
        ``v2_idx`` becomes its child.  Nothing happens if both already share
        a root (an earlier merge joined them).

        :param v1_idx: index of the first vertex (its root is kept)
        :param v2_idx: index of the second vertex (its root is absorbed)
        :param v_bar: homogeneous target coordinate with 4 components
            (shape (4,) or (4, 1)); it is de-homogenised by its last entry.
        """
        _, v1_root_idx = self.get_vertex(v1_idx)
        _, v2_root_idx = self.get_vertex(v2_idx)

        # If the vertices are already the same, do nothing
        # This happens when a prior merge causes vertices to merge.
        if v1_root_idx == v2_root_idx:
            return

        v_bar = v_bar.flatten()
        v_bar = v_bar[:3] / v_bar[3]

        self.vertices[v1_root_idx] = ModVertex(v_bar, -1)

        # Absorbed vertex: coordinate is meaningless, parent is the survivor.
        self.vertices[v2_root_idx] = ModVertex(
            np.array([None, None, None]), v1_root_idx
        )

        self.n_vertices -= 1

    def get_vertices_and_faces(self, reduce=False):
        """
        Get vertices and faces of mesh.

        :param reduce: whether to update the internal representation of the
            mesh to the reduced version.
        :return: (vertices, faces) tuple of numpy arrays
        """
        reduced_vertices, new_faces = self._reduce(set_reduced=reduce)
        vertices = [v.coord for v in reduced_vertices.values()]
        return np.array(vertices), new_faces.copy()

    def _reduce(self, set_reduced=False):
        """
        Reduce faces to only use root vertices.  Condense vertices to only
        include root vertices.

        :param set_reduced: if True, replace the internal vertices/faces with
            the condensed versions.
        :return: (dict new index -> ModVertex, (m, 3) int array of faces)
        """
        # Construct new vertex list of only roots and map old root index ->
        # new compact index (roots keep their original relative order).
        reduced_vertices = {}
        old_to_new_map = {}
        vert_idx = 0
        for i in range(len(self.vertices)):
            coord, root_idx = self.get_vertex(i)
            if i == root_idx:
                reduced_vertices[vert_idx] = ModVertex(coord, -1)
                old_to_new_map[i] = vert_idx
                vert_idx += 1

        # Update faces to use new roots.  Faces whose corners collapsed onto
        # the same root are kept (as degenerate faces).
        new_faces = [
            [old_to_new_map[self.get_vertex(face[i])[1]] for i in range(3)]
            for face in self.faces
        ]

        # reshape keeps a (0, 3) shape when there are no faces, so callers
        # can still index faces[:, k].
        new_faces = np.array(new_faces, dtype=int).reshape(-1, 3)

        if set_reduced:
            self.vertices = reduced_vertices
            self.faces = new_faces
        return reduced_vertices, new_faces

In [325]:
# Helper functions
def parse_obj_file(obj_file):
    """
    Parses a .obj file and returns its vertices and triangular faces.

    Only ``v`` and ``f`` records are read; every other record is ignored.
    Polygons with more than three vertices are fan-triangulated, and negative
    (relative) face indices are resolved against the vertices read so far.

    :param obj_file: path of the .obj file to parse
    :return: (n, 3) float array of vertices, (m, 3) int array of 0-based
        vertex indices for each triangular face
    :raises ValueError: if a face record lists fewer than three vertices
    """
    vertices = []
    faces = []
    with open(obj_file, "r") as f:
        for line in f:
            tokens = line.split()
            if not tokens:
                continue
            if tokens[0] == "v":
                vertices.append(
                    [float(tokens[1]), float(tokens[2]), float(tokens[3])]
                )
            elif tokens[0] == "f":
                # Tokens look like "v", "v/vt", "v//vn" or "v/vt/vn"; keep v.
                vertex_idxs = []
                for token in tokens[1:]:
                    vertex_idx = int(token.split("/")[0])
                    if vertex_idx < 0:
                        # Relative index: -1 is the most recently defined vertex.
                        vertex_idx += len(vertices) + 1
                    vertex_idxs.append(vertex_idx)
                if len(vertex_idxs) < 3:
                    raise ValueError(
                        f"Face with fewer than 3 vertices: {line.strip()!r}"
                    )
                # Fan-triangulate; a plain triangle yields exactly one face.
                for k in range(1, len(vertex_idxs) - 1):
                    faces.append(
                        [vertex_idxs[0], vertex_idxs[k], vertex_idxs[k + 1]]
                    )

    # reshape keeps (0, 3) shapes for files without vertices/faces.
    vertices = np.array(vertices, dtype=float).reshape(-1, 3)
    faces = np.array(faces, dtype=int).reshape(-1, 3)
    faces = faces - 1  # Convert from 1 index to 0 index
    return vertices, faces


def get_edge_go(vertices, faces, color="blue", name="edges"):
    """
    Builds a plotly wireframe trace of the edges of every face.

    Each face contributes the closed loop v1 -> v2 -> v3 -> v1 followed by a
    point of ``None`` coordinates, which breaks the polyline so consecutive
    faces are not joined by a stray segment.

    :param vertices: (n, 3) array of vertex coordinates
    :param faces: (m, 3) array of vertex indices for each face
    :param color: colour of the edge lines
    :param name: legend name of the trace
    :return: go.Scatter3d trace in "lines" mode
    """
    edges = []
    for face in faces:
        v1 = vertices[face[0]]
        v2 = vertices[face[1]]
        v3 = vertices[face[2]]
        edges.extend([v1, v2, v3, v1, np.array([None, None, None])])

    # Build the coordinate array once; an empty face list gives a (0, 3)
    # array so the column slices below still work.
    coords = np.array(edges) if edges else np.empty((0, 3))

    return go.Scatter3d(
        x=coords[:, 0],
        y=coords[:, 1],
        z=coords[:, 2],
        mode="lines",
        # In "lines" mode the segment colour is taken from `line`, not from
        # `marker`; the marker settings are kept for backward compatibility.
        line=dict(color=color),
        marker=dict(size=2, color=color),
        name=name,
    )


def visualize_mesh(vertices, faces, title=None):
    """
    Shows a mesh in an interactive plotly figure.

    The figure overlays three traces: the translucent surface, the vertices
    as small markers, and the face edges as a wireframe.

    :param vertices: (n, 3) array of vertex coordinates
    :param faces: (m, 3) array of vertex indices for each face
    :param title: figure title; "Mesh" when omitted
    """
    xs, ys, zs = vertices[:, 0], vertices[:, 1], vertices[:, 2]

    surface = go.Mesh3d(
        x=xs,
        y=ys,
        z=zs,
        i=faces[:, 0],
        j=faces[:, 1],
        k=faces[:, 2],
        color="lightpink",
        opacity=0.50,
    )
    points = go.Scatter3d(
        x=xs,
        y=ys,
        z=zs,
        mode="markers",
        marker=dict(size=2, color="blue"),
        name="vertices",
    )
    wireframe = get_edge_go(vertices, faces, color="blue", name="edges")

    if title is None:
        title = "Mesh"

    figure = go.Figure(data=[surface, points, wireframe])
    figure.update_layout(
        # y-up camera centred on the origin
        scene_camera=dict(
            up=dict(x=0, y=1, z=0),
            center=dict(x=0, y=0, z=0),
        ),
        title=title,
    )
    figure.show()

In [326]:
# Load the walkthrough mesh and show it before any simplification.
vertices, faces = parse_obj_file(MESH_FNAME)
mesh = ModMesh(vertices, faces)
# Re-read the arrays through ModMesh.  With no contractions yet this is the
# same geometry; later cells read these `vertices` / `faces` globals.
vertices, faces = mesh.get_vertices_and_faces()
visualize_mesh(vertices, faces, "Original Mesh")

### Compute Quadric Error Matrices

In [327]:
def compute_plane(p1, p2, p3):
    """
    Computes the plane defined by three points
    :param p1: first point
    :param p2: second point
    :param p3: third point
    :return: np.array([[a, b, c, d]]).T with (a, b, c) the unit normal and
        a*x + b*y + c*z + d = 0 on the plane.  A degenerate (collinear or
        repeated) triple defines no plane and yields the zero vector.
    """
    v1 = p2 - p1
    v2 = p3 - p1
    normal = np.cross(v1, v2)
    norm = np.linalg.norm(normal)
    if norm == 0:
        # Zero-area triangle: dividing by the norm would fill the quadric with
        # NaN.  A zero plane makes its K_p = p p^T vanish instead.
        return np.zeros((4, 1))
    normal = normal / norm
    d = -np.dot(normal, p1)
    return np.concatenate((normal, np.array([d]))).reshape((4, 1))


def get_Q_matricies(vertices, faces):
    """
    Accumulates the initial error quadric of every vertex.

    Each face's plane p contributes the fundamental quadric p p^T to every
    vertex index listed in that face.

    :param vertices: (n, 3) array of vertex coordinates
    :param faces: (m, 3) array of vertex indices for each face
    :return: (n, 4, 4) array; entry i is the summed quadric of vertex i
    """
    quadrics = np.zeros((vertices.shape[0], 4, 4))
    for tri in faces:
        plane = compute_plane(vertices[tri[0]], vertices[tri[1]], vertices[tri[2]])
        fundamental = plane @ plane.T
        for corner in tri:
            quadrics[corner] += fundamental

    return quadrics

In [328]:
# Per-vertex quadrics for the walkthrough mesh; print one as a sanity check.
Q_matricies = get_Q_matricies(vertices, faces)
print("First Q Matrix:")
print(Q_matricies[0])

First Q Matrix:
[[2. 0. 0. 2.]
 [0. 2. 0. 2.]
 [0. 0. 2. 2.]
 [2. 2. 2. 6.]]


### Select Valid Pairs

In [329]:
def get_pairs(faces, t, vertices=None):
    """
    Computes the candidate vertex pairs for contraction.

    A pair is valid if its two vertices share a face edge, or (when ``t > 0``)
    if they lie strictly closer than ``t`` to each other.

    :param faces: (m, 3) array of vertex indices for each face
    :param t: threshold for distance between vertices; ``t <= 0`` selects
        edge pairs only.
    :param vertices: (n, 3) array of vertex coordinates, only read when
        ``t > 0``.  Defaults to the notebook-global ``vertices`` so existing
        ``get_pairs(faces, t)`` calls keep working.
    :return: set of ``(i, j)`` index tuples with ``i < j``

    Note: Set the threshold to 0 for large models.  The distance pass still
    compares every pair of vertices (O(n^2) work, vectorised per row); a
    KD tree would scale better.
    """
    pairs = set()

    # Add edges, skipping degenerate edges (both ends the same vertex)
    for face in faces:
        for a, b in ((0, 1), (1, 2), (2, 0)):
            v1, v2 = int(face[a]), int(face[b])
            if v1 != v2:
                pairs.add((min(v1, v2), max(v1, v2)))

    # Add thresholded distances
    if t > 0:
        if vertices is None:
            # Backward compatibility: this function used to read the global.
            vertices = globals()["vertices"]
        coords = np.asarray(vertices, dtype=float)
        for i in range(len(coords) - 1):
            # Compare only against later vertices so each unordered pair is
            # measured once.
            dists = np.linalg.norm(coords[i + 1:] - coords[i], axis=1)
            for offset in np.flatnonzero(dists < t):
                pairs.add((i, i + 1 + int(offset)))

    return pairs

In [330]:
# threshold == 0 keeps only edge pairs (no extra pairs within a distance).
threshold = 0.0
pairs = get_pairs(faces, threshold)

### Compute Contraction Targets and Costs

In [331]:
def get_cost(v1, v2, Q1, Q2):
    """
    Computes the contraction target of (v1, v2) and the cost of contracting them.

    The target v_bar minimises v^T (Q1 + Q2) v over homogeneous points with
    last coordinate 1.  When that linear system is (numerically) singular,
    the best of v1, v2 and their midpoint is used instead.

    :param v1: first vertex  (3, )
    :param v2: second vertex  (3, )
    :param Q1: Q matrix for v1
    :param Q2: Q matrix for v2
    :return: v_bar (4, ) homogeneous target with v_bar[3] == 1,
        cost of contracting v1 and v2 (float)
    """
    Q_bar = Q1 + Q2
    # Replace the last row by [0, 0, 0, 1] so solving A v = e4 gives the
    # minimiser with w == 1.
    working_Q_bar = Q_bar.copy()
    working_Q_bar[3, :] = 0
    working_Q_bar[3, 3] = 1
    if np.linalg.cond(working_Q_bar) < 1 / np.finfo(float).eps:
        v_bar = np.linalg.solve(working_Q_bar, np.array([0.0, 0.0, 0.0, 1.0]))
    else:
        print("Singular Q")
        # Find best v by checking endpoints and midpoint if
        # Q_bar is not invertible.  Work with flat (4,) vectors so every
        # quadratic form is a scalar.
        h1 = np.append(np.asarray(v1, dtype=float), 1.0)
        h2 = np.append(np.asarray(v2, dtype=float), 1.0)
        h_mid = (h1 + h2) / 2
        c1 = h1 @ Q_bar @ h1
        c2 = h2 @ Q_bar @ h2
        c_bar = h_mid @ Q_bar @ h_mid

        if c_bar < c1 and c_bar < c2:
            v_bar = h_mid
        elif c1 < c2:
            v_bar = h1
        else:
            v_bar = h2

    return v_bar, float(v_bar @ Q_bar @ v_bar)


# def find_contraction_pair(pairs, vertices, Q_matricies):
# """
# Finds the best pair of vertices to contract
# :param pairs: array of vertex pairs
# :param vertices: array of vertices
# :param Q_matricies: array of Q matrix for each vertex
# :return: best vertex pair to contract, v_bar, cost of contraction"""
# best_cost = np.inf
# best_v_bar = None
# best_pair = None
# for pair in pairs:
#     v1 = vertices[pair[0]]
#     v2 = vertices[pair[1]]
#     Q1 = Q_matricies[pair[0]]
#     Q2 = Q_matricies[pair[1]]
#     v_bar, cost = get_cost(v1, v2, Q1, Q2)
#     if cost < best_cost:
#         best_cost = cost
#         best_v_bar = v_bar
#         best_pair = pair
# return best_pair, best_v_bar, best_cost


def find_contraction_pair(mesh: ModMesh, Q_matricies, candidate_pairs=None):
    """
    Finds the cheapest pair of distinct vertices to contract.

    Pairs whose endpoints already resolve to the same root vertex (they were
    joined by earlier contractions) are skipped: contracting them would not
    remove a vertex.

    :param mesh: mesh supplying the current (root) coordinate of each endpoint
    :param Q_matricies: array of Q matrix per original vertex index;
        Q_matricies[i] is expected to hold the accumulated quadric of the
        cluster vertex i currently belongs to.
    :param candidate_pairs: iterable of (i, j) vertex-index pairs.  Defaults
        to the notebook-global ``pairs`` so existing two-argument calls keep
        working.
    :return: best vertex pair to contract, v_bar, cost of contraction.
        Returns (None, None, inf) when no pair yields a finite cost below inf.
    """
    if candidate_pairs is None:
        candidate_pairs = pairs

    best_cost = np.inf
    best_v_bar = None
    best_pair = None
    for pair in candidate_pairs:
        # get_vertex returns (coordinate, root index); only the coordinate
        # is a vertex position.
        v1, root1 = mesh.get_vertex(pair[0])
        v2, root2 = mesh.get_vertex(pair[1])
        if root1 == root2:
            continue
        Q1 = Q_matricies[pair[0]]
        Q2 = Q_matricies[pair[1]]
        v_bar, cost = get_cost(v1, v2, Q1, Q2)
        if cost < best_cost:
            best_cost = cost
            best_v_bar = v_bar
            best_pair = pair
    return best_pair, best_v_bar, best_cost

In [332]:
vertices, faces = mesh.get_vertices_and_faces()
# best_pair, best_v_bar, best_cost = find_contraction_pair(pairs, vertices, Q_matricies)
# find_contraction_pair scans the notebook-global `pairs` built above.
best_pair, best_v_bar, best_cost = find_contraction_pair(mesh, Q_matricies)

### Remove the pair with the lowest cost

In [333]:
def remove_pair(mesh: ModMesh, Q_matricies: np.ndarray, pair: tuple, v_bar: np.ndarray):
    """
    Contracts a vertex pair and merges the pair's quadrics.

    :param mesh: mesh object (modified in place)
    :param Q_matricies: array of Q matrix per original vertex index
        (modified in place)
    :param pair: indicies of the pair to contract
    :param v_bar: homogeneous target position of the merged vertex
        (4 components)
    :return: (mesh, Q_matricies) - the same objects that were passed in
    """
    _, root1 = mesh.get_vertex(pair[0])
    _, root2 = mesh.get_vertex(pair[1])
    if root1 == root2:
        # Already joined by an earlier contraction: the mesh does not change,
        # and adding the quadrics again would double-count the same planes.
        return mesh, Q_matricies

    Q_bar = Q_matricies[root1] + Q_matricies[root2]
    mesh.combine_vertices(pair[0], pair[1], v_bar)

    # combine_vertices keeps root1 as the surviving root.  Give every vertex
    # of the merged cluster the combined quadric so a lookup through any of
    # its original indices (not only the two in `pair`) sees the same matrix.
    for idx in mesh.vertices:
        if idx < len(Q_matricies) and mesh.get_vertex(idx)[1] == root1:
            Q_matricies[idx] = Q_bar
    return mesh, Q_matricies


visualize_mesh(vertices, faces, "Original Mesh")
# Contract the cheapest pair found above and drop it from the candidate set.
mesh, Q_matricies = remove_pair(mesh, Q_matricies, best_pair, best_v_bar)
pairs.remove(best_pair)

# Re-extract the geometry so the plot reflects the contraction.
vertices, faces = mesh.get_vertices_and_faces()
visualize_mesh(vertices, faces, "First Contraction")

### Run iteratively until the desired number of vertices is reached

We can put all components of the algorithm together to simplify a more complex mesh.

In [334]:
# Settings for the full simplification run below.
MESH_FNAME = "../assets/bunny_1k.obj"
# MESH_FNAME = "../assets/cube.obj"
DESIRED_VERTICES = 100  # the loop stops once len(mesh) <= this value
DISPLAY_EVERY = 100  # plot on every DISPLAY_EVERY-th iteration

In [337]:
# Load mesh
vertices, faces = parse_obj_file(MESH_FNAME)
mesh = ModMesh(vertices, faces)

# Compute Q matricies
Q_matricies = get_Q_matricies(vertices, faces)

# Get all valid pairs
threshold = 0.01
pairs = get_pairs(faces, threshold)

# Contract mesh until desired number of vertices is reached
iters = 0
visualize_mesh(*mesh.get_vertices_and_faces(), f"Original Mesh: {len(mesh)} vertices")
# Note, sometimes a pair will point to the same vertex.  In that case,
# we don't decrease the number of vertices.
while len(mesh) > DESIRED_VERTICES:
    best_pair, best_v_bar, best_cost = find_contraction_pair(mesh, Q_matricies)
    mesh, Q_matricies = remove_pair(mesh, Q_matricies, best_pair, best_v_bar)
    pairs.remove(best_pair)
    if iters % DISPLAY_EVERY == 0:
        title = f"Iteration {iters}: {len(mesh)} vertices"
        visualize_mesh(*mesh.get_vertices_and_faces(), title)
    iters += 1

visualize_mesh(*mesh.get_vertices_and_faces(), f"Final Mesh: {len(mesh)} vertices")

### Compare original and simplified meshes

Note how the vertices have moved in the simplified mesh while still
retaining the overall form.

In [338]:
# Overlay the original mesh (blue) and the simplified mesh (red/pink).
original_vertices, original_faces = parse_obj_file(MESH_FNAME)
new_vertices, new_faces = mesh.get_vertices_and_faces(reduce=False)

new_mesh_viz = go.Mesh3d(
    x=new_vertices[:, 0],
    y=new_vertices[:, 1],
    z=new_vertices[:, 2],
    i=new_faces[:, 0],
    j=new_faces[:, 1],
    k=new_faces[:, 2],
    color="lightpink",
    opacity=0.50,
)

old_mesh_viz = go.Mesh3d(
    x=original_vertices[:, 0],
    y=original_vertices[:, 1],
    z=original_vertices[:, 2],
    i=original_faces[:, 0],
    j=original_faces[:, 1],
    k=original_faces[:, 2],
    color="lightblue",
    opacity=0.50,
)

scatter_old = go.Scatter3d(
    x=original_vertices[:, 0],
    y=original_vertices[:, 1],
    z=original_vertices[:, 2],
    mode="markers",
    marker=dict(size=2, color="blue"),
    name="Original Vertices",
)

edges_old = get_edge_go(
    original_vertices, original_faces, color="blue", name="Original Edges"
)

scatter_new = go.Scatter3d(
    x=new_vertices[:, 0],
    y=new_vertices[:, 1],
    z=new_vertices[:, 2],
    mode="markers",
    marker=dict(size=2, color="red"),
    name="Simplified Vertices",
)

edges_new = get_edge_go(new_vertices, new_faces, color="red", name="Simplified Edges")

camera = dict(
    up=dict(x=0, y=1, z=0),
    center=dict(x=0, y=0, z=0),
)

fig = go.Figure(
    data=[old_mesh_viz, new_mesh_viz, scatter_old, scatter_new, edges_old, edges_new]
)
fig.update_layout(scene_camera=camera, title="Mesh Comparison")
fig.show()