In [2]:
import open3d as o3d
import numpy as np
import copy
import cv2
import matplotlib.pyplot as plt
import os

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


Load mesh

In [207]:
# Load the meshes
mesh_name = "001201" # must be 4 digits: xx01 indicates lower jaw, xx02 indicates upper jaw
origin_file_path = f"D:\sunny\Codes\DPS\data_teethseg\origin\{mesh_name}_origin.ply"
label_file_path = f"D:\sunny\Codes\DPS\data_teethseg\label\{mesh_name}.ply"
origin_mesh = o3d.io.read_triangle_mesh(origin_file_path)
label_mesh = o3d.io.read_triangle_mesh(label_file_path)

# recentre the meshes
origin_mesh = origin_mesh.translate(-origin_mesh.get_center())
label_mesh = label_mesh.translate(-label_mesh.get_center())



In [208]:
origin_mesh.compute_vertex_normals()
vertices = np.asarray(origin_mesh.vertices)
triangles = np.asarray(origin_mesh.triangles)   

Region Growth

In [209]:
def find_seed_points(mesh): # grow from top (upper face)
    # Get the vertices as a numpy array
    vertices = np.asarray(mesh.vertices)

    # get the index of the vertex with the minimum z value within the range
    x_range = [-0.5, 0.5]
    possible_seed_indices = np.where((vertices[:, 0] > x_range[0]) & (vertices[:, 0] < x_range[1]))[0]
    seed_index = possible_seed_indices[np.argmin(vertices[possible_seed_indices, 2])]

    # get coordinates of the seed point
    seed_point = vertices[seed_index]
    print("Seed point:", seed_point)
    # unlabeled_indices = np.where(seg_labels == -1)[0]
    # seed_index = unlabeled_indices[np.argmin(vertices[unlabeled_indices, 1])]
    
    return seed_index

def create_adjacency_list(mesh):
    adjacency_list = {i: set() for i in range(len(mesh.vertices))}
    triangles = np.asarray(mesh.triangles)
    for triangle in triangles:
        for i, j in zip(triangle, triangle[[1, 2, 0]]):
            adjacency_list[i].add(j)
            adjacency_list[j].add(i)
    return adjacency_list


def region_growing_segmentation(mesh, adjacency_list, seed_index, normal_diff_threshold, normal_y_threshold):
    """
    Perform region growing segmentation on a mesh starting from a seed index using only z-axis distance.

    Parameters:
    - mesh: open3d.geometry.TriangleMesh, the input mesh
    - adjacency_list: dict, adjacency list of vertices
    - seed_index: int, the index of the seed vertex (start from gum)

    - normal_threshold: float, normal dot product threshold for region growing
    - color_threshold: float, color difference threshold for region growing

    Returns:
    - seg_labels: np.ndarray, an array of seg_labels for each vertex in the mesh
    
    Labels:
    - 0: within the upper region 
    - 1: boundary of the upper region 

    - -1: unlabeled, upper face 
    """

    vertices = np.asarray(mesh.vertices)
    normals = np.asarray(mesh.vertex_normals)
    colors = np.asarray(mesh.vertex_colors)

    region=[seed_index]
    seg_labels = np.full(len(vertices), -1, dtype=int) # -1: unlabeled
    print(seg_labels.shape)

    # grow from the upper region
    seg_labels = grow_from_seed(region, vertices, normals, colors, adjacency_list, 
                                seg_labels, normal_diff_threshold, normal_y_threshold)

    return seg_labels

def grow_from_seed(region, vertices, normals, colors, adjacency_list, 
                   seg_labels, normal_diff_threshold, normal_y_threshold):
    while region:
        current_index = region.pop()
        current_normal = normals[current_index]

        
        for neighbor_index in adjacency_list[current_index]:
            if seg_labels[neighbor_index] == -1:
                neighbor_vertex = vertices[neighbor_index]
                neighbor_normal = normals[neighbor_index]
                neighbor_color = colors[neighbor_index]
                
                normal_dot_y = np.dot(current_normal[1], neighbor_normal[1]) # track normal change in y-axis
                print("norm_dot:", normal_dot_y)
                print("neighbor_y:", neighbor_vertex[1])
                
                if normal_dot_y > normal_diff_threshold and neighbor_vertex[1] > normal_y_threshold: # normal change in y-axis and within the upper region
                    seg_labels[neighbor_index] = 0 # within the upper region
                    region.append(neighbor_index)
                else:
                    seg_labels[neighbor_index] = 1 # boundary
    
    return seg_labels


In [210]:
# Visualize the segmentation
def display_region_growth_outcome(mesh, seg_labels, seed_index):
    # Duplicate the original mesh
    mesh_copy = copy.deepcopy(mesh)
    
    # Update the vertex colors based on labels for visualization
    colors = np.asarray(mesh_copy.vertex_colors)
    vertices = np.asarray(mesh_copy.vertices)
    for i in range(len(seg_labels)):
        if seg_labels[i] == 0:
            colors[i] = [0, 1, 0]  # Green for outer region
        elif seg_labels[i] == 1:
            colors[i] = [0, 0, 1]  # Blue for inner region
        elif seg_labels[i] == 2:
            colors[i] = [1, 1, 0]  # yellow for outer boundaries
        elif seg_labels[i] == 3:
            colors[i] = [1, 0, 1] # purple for inner boundaries
        # Else keep the original color for label -1 (default color)
        else:
            assert seg_labels[i] == -1, "Invalid label value"
    
    # draw a sphere at the seed point
    seed_point = vertices[seed_index]
    sphere = o3d.geometry.TriangleMesh.create_sphere(radius=2)
    sphere = sphere.translate(seed_point)


    # Assign the updated colors back to the mesh
    mesh_copy.vertex_colors = o3d.utility.Vector3dVector(colors)
    
    # coordinates axes
    axes = o3d.geometry.TriangleMesh.create_coordinate_frame(size=20, origin=[0, 0, 0])
    # Visualize the segmented mesh with color labels
    o3d.visualization.draw_geometries([mesh_copy, sphere, axes])

In [211]:
adjacency_list = create_adjacency_list(origin_mesh)

# Initialize seg_labels array    
seed_index = find_seed_points(origin_mesh)

# Set the threshold values
normal_diff_threshold = 0.12
normal_y_threshold = -1.36

seg_labels = region_growing_segmentation(origin_mesh, adjacency_list, seed_index, normal_diff_threshold=normal_diff_threshold, normal_y_threshold=normal_y_threshold)

Seed point: [ 0.1804302  -4.25467998  5.75339616]
(20210,)
norm_dot: 0.9366416725968975
neighbor_y: -4.200435285989574
norm_dot: 0.9161926746426715
neighbor_y: -4.239801470157595


In [212]:
print("unlabeled:")
print(np.where(seg_labels == -1)[0].shape)
print("within upper region:")
print(np.where(seg_labels == 0)[0].shape)

print("on upper boundary:")
print(np.where(seg_labels == 1)[0].shape)


display_region_growth_outcome(origin_mesh, seg_labels, seed_index=seed_index)

unlabeled:
(20208,)
within upper region:
(0,)
on upper boundary:
(2,)
