Packages:
$ pip install pymeshlab

In [1]:
import open3d as o3d
import numpy as np
import os
import shutil

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [5]:
def find_seed_point(mesh, jaw_type, labels, round_num): # grow from end boundary of gum
    # Get the vertices as a numpy array
    vertices = np.asarray(mesh.vertices)
    # Find the index of the vertex with:
    # - the min y-coordinate for "01" lower jaw
    # - the max y-coordinate for "02" upper jaw
    if jaw_type == "01":
        if round_num == 1:
            seed_index = np.argmin(vertices[:, 1])
        elif round_num == 2:
            # Find the index of the unlabeled vertex with the maximum y-coordinate for lower jaw
            unlabeled_indices = np.where(labels == -1)[0]
            seed_index = unlabeled_indices[np.argmin(vertices[unlabeled_indices, 1])]
        else:
            raise ValueError("round_num must be 1 or 2.")
    
    
    elif jaw_type == "02":
        if round_num == 1:
            seed_index = np.argmax(vertices[:, 1])
        elif round_num == 2:
            # Find the index of the unlabeled vertex with the minimum y-coordinate for upper jaw
            unlabeled_indices = np.where(labels == -1)[0]
            seed_index = unlabeled_indices[np.argmax(vertices[unlabeled_indices, 1])]
        else:
            raise ValueError("round_num must be 1 or 2.")
    else:
        raise ValueError("jaw_type must be '01' for lower jaw or '02' for upper jaw.")
    return seed_index

In [6]:
def create_adjacency_list(mesh):
    adjacency_list = {i: set() for i in range(len(mesh.vertices))}
    triangles = np.asarray(mesh.triangles)
    for triangle in triangles:
        for i, j in zip(triangle, triangle[[1, 2, 0]]):
            adjacency_list[i].add(j)
            adjacency_list[j].add(i)
    return adjacency_list

In [7]:
def region_growing_segmentation(mesh, adjacency_list, seed_index, labels, y_threshold=0.02, normal_threshold=0.9, color_threshold=0.1):
    """
    Perform region growing segmentation on a mesh starting from a seed index using only z-axis distance.

    Parameters:
    - mesh: open3d.geometry.TriangleMesh, the input mesh
    - adjacency_list: dict, adjacency list of vertices
    - seed_index: int, the index of the seed vertex (start from gum)
    - y_threshold: float, y-axis distance threshold for region growing
    - normal_threshold: float, normal dot product threshold for region growing
    - color_threshold: float, color difference threshold for region growing

    Returns:
    - labels: np.ndarray, an array of labels for each vertex in the mesh
    
    Labels:
    - 1: within the gum region
    - 0: gum-boundary region
    - -1: unlabeled, outside the gum region (i.e. teeth region)
    """

    

    vertices = np.asarray(mesh.vertices)
    normals = np.asarray(mesh.vertex_normals)
    colors = np.asarray(mesh.vertex_colors)

    
    region = [seed_index]
    labels[seed_index] = 1

    while region:
        current_index = region.pop()
        current_vertex = vertices[current_index]
        current_normal = normals[current_index]
        current_color = colors[current_index]
        
        for neighbor_index in adjacency_list[current_index]:
            if labels[neighbor_index] == -1:
                neighbor_vertex = vertices[neighbor_index]
                neighbor_normal = normals[neighbor_index]
                neighbor_color = colors[neighbor_index]
                
                y_distance = abs(current_vertex[2] - neighbor_vertex[2])
                normal_dot = np.dot(current_normal, neighbor_normal)
                color_diff = np.linalg.norm(current_color - neighbor_color)
                
                if y_distance < y_threshold and normal_dot > normal_threshold and color_diff < color_threshold:
                    labels[neighbor_index] = 0
                    region.append(neighbor_index)
                else:
                    labels[neighbor_index] = 1
    
    return labels

In [8]:
def process_all_files(input_dir, output_dir, y_threshold, normal_threshold, color_threshold):
    
    # Iterate over all files in the input directory
    for file_name in os.listdir(input_dir):
        if file_name.endswith(".ply"):
            input_file_path = os.path.join(input_dir, file_name)
            jaw_type = file_name[4:6]
            base_name = file_name.split(".")[0]
            output_file_path = os.path.join(output_dir, f"{base_name}_seg.ply")


            # Load the mesh
            mesh = o3d.io.read_triangle_mesh(input_file_path)
            if mesh.is_empty():
                print(f"Failed to load the mesh file: {input_file_path}")
                continue
            mesh.compute_vertex_normals()

            vertices = np.asarray(mesh.vertices)
            triangles = np.asarray(mesh.triangles)

            # Create adjacency list for vertices
            adjacency_list = create_adjacency_list(mesh)
            
            # Initialize labels array
            labels = np.full(len(vertices), -1, dtype=int)
            
            # First round of region growing
            seed_index1 = find_seed_point(mesh, jaw_type, labels, round_num=1)
            print("Seed index1: ", seed_index1)
            labels = region_growing_segmentation(mesh, adjacency_list, seed_index1, labels, y_threshold, normal_threshold, color_threshold)

            # Second round of region growing with a new seed from unlabeled points
            seed_index2 = find_seed_point(mesh, jaw_type, labels, round_num=2)
            print("Seed index2: ", seed_index2)
            labels = region_growing_segmentation(mesh, adjacency_list, seed_index2, labels, y_threshold, normal_threshold, color_threshold)

            # # Debug: Check the labels array
            # print(f"Labels array {file_name}: {labels}")
            
            # Check if the labels contain any segmented vertices
            if len(vertices[labels == 1]) == 0 or len(triangles[np.all(labels[triangles] == 1, axis=1)]) == 0:
                raise ValueError("The resulting gum mesh is empty. Check the segmentation criteria and thresholds.")
            
            
            
            # Extract triangles of the teeth region (all vertices of label = -1)
            teeth_triangles = triangles[np.all(labels[triangles] == -1, axis=1)]
            if len(teeth_triangles) == 0:
                raise ValueError("No teeth region found.")
            # Duplicate the original mesh
            shutil.copyfile(input_file_path, output_file_path)
            teeth_mesh = o3d.io.read_triangle_mesh(output_file_path)
            teeth_mesh.triangles = o3d.utility.Vector3iVector(teeth_triangles)
            

            # Preserve the largest connected component of the teeth region
            # Remove disconnected small component pieces
            comps_label = np.array(teeth_mesh.cluster_connected_triangles()[0])
            assert len(comps_label) == len(teeth_triangles), "Mismatch between number of triangles and size of comps_label array."
            maxcomp = np.argmax(np.bincount(comps_label)) # largest connected component
            mask_maxcomp = (comps_label == maxcomp)
            teeth_triangles = teeth_triangles[mask_maxcomp]
            teeth_mesh.triangles = o3d.utility.Vector3iVector(teeth_triangles)
            o3d.io.write_triangle_mesh(output_file_path, teeth_mesh)
                        
                        
            print(f"Processed {file_name} and saved the largest component to {output_file_path}")


In [9]:
# TODO: Set input values
input_dir = "E:\\OneDrive\\OneDrive - University of Cambridge\\Documents\\Coding\\DPS_hku\\data_new_65536\\Origin"
output_dir = "E:\OneDrive\OneDrive - University of Cambridge\Documents\Coding\DPS_hku\gum_removal\segment\origin"

y_threshold = 11.0
normal_threshold = 0.986 # lower
# normal_threshold = 0.985 # upper
color_threshold = 0.05

In [10]:
process_all_files(input_dir, output_dir, y_threshold, normal_threshold, color_threshold)

Seed index1:  29148
Seed index2:  14629
Processed 000101_origin.ply and saved the largest component to E:\OneDrive\OneDrive - University of Cambridge\Documents\Coding\DPS_hku\gum_removal\segment\origin\000101_origin.ply_seg.ply
Seed index1:  2103
Seed index2:  3954
Processed 000102_origin.ply and saved the largest component to E:\OneDrive\OneDrive - University of Cambridge\Documents\Coding\DPS_hku\gum_removal\segment\origin\000102_origin.ply_seg.ply
Seed index1:  20602
Seed index2:  20217
Processed 000201_origin.ply and saved the largest component to E:\OneDrive\OneDrive - University of Cambridge\Documents\Coding\DPS_hku\gum_removal\segment\origin\000201_origin.ply_seg.ply
Seed index1:  33467
Seed index2:  14966
Processed 000202_origin.ply and saved the largest component to E:\OneDrive\OneDrive - University of Cambridge\Documents\Coding\DPS_hku\gum_removal\segment\origin\000202_origin.ply_seg.ply
Seed index1:  23253
Seed index2:  22412
Processed 000301_origin.ply and saved the largest 

ValueError: The resulting gum mesh is empty. Check the segmentation criteria and thresholds.