In [2]:
import open3d as o3d
import numpy as np
import copy
import cv2
import matplotlib.pyplot as plt
import os

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


Load mesh

In [5]:
# Load the meshes
mesh_name = "000101" # must be 4 digits: xx01 indicates lower jaw, xx02 indicates upper jaw
origin_file_path = f"D:\sunny\Codes\DPS\data_teethseg\origin\{mesh_name}_origin.ply"
label_file_path = f"D:\sunny\Codes\DPS\data_teethseg\label\{mesh_name}.ply"
origin_mesh = o3d.io.read_triangle_mesh(origin_file_path)
label_mesh = o3d.io.read_triangle_mesh(label_file_path)



In [6]:
origin_mesh.compute_vertex_normals()
vertices = np.asarray(origin_mesh.vertices)
triangles = np.asarray(origin_mesh.triangles)   

Region Growth

In [63]:
def find_seed_points(mesh): # grow from front (outer) or back (inner)
    # Get the vertices as a numpy array
    vertices = np.asarray(mesh.vertices)

    # Find the index of the vertex with the extreme z values within small-range x values
    x_range = [-0.1, 0.1]
    # get the indices of vertices within the x_range
    possible_seed_indices = np.where((vertices[:, 0] > x_range[0]) & (vertices[:, 0] < x_range[1]))[0]

    # get the index of the vertex with the minimum and maximum z value
    inner_seed_index = possible_seed_indices[np.argmin(vertices[possible_seed_indices, 2])]
    outer_seed_index = possible_seed_indices[np.argmax(vertices[possible_seed_indices, 2])]
    # unlabeled_indices = np.where(seg_labels == -1)[0]
    # seed_index = unlabeled_indices[np.argmin(vertices[unlabeled_indices, 1])]
    
    return [inner_seed_index, outer_seed_index]

def create_adjacency_list(mesh):
    adjacency_list = {i: set() for i in range(len(mesh.vertices))}
    triangles = np.asarray(mesh.triangles)
    for triangle in triangles:
        for i, j in zip(triangle, triangle[[1, 2, 0]]):
            adjacency_list[i].add(j)
            adjacency_list[j].add(i)
    return adjacency_list


def region_growing_segmentation(mesh, adjacency_list, seed_indices, normal_threshold=0.99, color_threshold=0.1):
    """
    Perform region growing segmentation on a mesh starting from a seed index using only z-axis distance.

    Parameters:
    - mesh: open3d.geometry.TriangleMesh, the input mesh
    - adjacency_list: dict, adjacency list of vertices
    - seed_index: int, the index of the seed vertex (start from gum)

    - normal_threshold: float, normal dot product threshold for region growing
    - color_threshold: float, color difference threshold for region growing

    Returns:
    - seg_labels: np.ndarray, an array of seg_labels for each vertex in the mesh
    
    Labels:
    - 0: within the outer region (even: out)
    - 1: within the inner region (odd: in)
    - 2: boundary of the outer region 
    - 3: boundary of the inner region

    - -1: unlabeled, upper face 
    """

    vertices = np.asarray(mesh.vertices)
    normals = np.asarray(mesh.vertex_normals)
    colors = np.asarray(mesh.vertex_colors)

    seed_idx_out, seed_idx_in = seed_indices
    region_out = [seed_idx_out]
    region_in = [seed_idx_in]
    seg_labels = np.full(len(vertices), -1, dtype=int) # -1: unlabeled
    print(seg_labels.shape)
    seg_labels[seed_idx_out] = 0
    seg_labels[seed_idx_in] = 1

    # grow from the outer region
    seg_labels = grow_from_seed(region_out, vertices, normals, colors, adjacency_list, 
                                seg_labels, normal_threshold, color_threshold, is_outer_region=True)
    # grow from the inner region
    seg_labels = grow_from_seed(region_in, vertices, normals, colors, adjacency_list,
                                seg_labels, normal_threshold, color_threshold, is_outer_region=False)
    return seg_labels

def grow_from_seed(region, vertices, normals, colors, adjacency_list, 
                   seg_labels, normal_threshold, color_threshold,
                   is_outer_region):
    if is_outer_region:
        lbs = [0, 2]
    else:
        lbs = [1, 3]

    while region:
        current_index = region.pop()
        current_vertex = vertices[current_index]
        current_normal = normals[current_index]
        current_color = colors[current_index]
        
        for neighbor_index in adjacency_list[current_index]:
            if seg_labels[neighbor_index] == -1:
                neighbor_vertex = vertices[neighbor_index]
                neighbor_normal = normals[neighbor_index]
                neighbor_color = colors[neighbor_index]
                
                normal_dot_y = np.dot(current_normal[1], neighbor_normal[1]) # track normal change in y-axis
                print(normal_dot_y)
                color_diff = np.linalg.norm(current_color - neighbor_color)
                
                if normal_dot_y > normal_threshold:
                    seg_labels[neighbor_index] = lbs[0]
                    region.append(neighbor_index)
                else:
                    seg_labels[neighbor_index] = lbs[1]
    
    return seg_labels


In [78]:
adjacency_list = create_adjacency_list(origin_mesh)

# Initialize seg_labels array    
seed_indices = find_seed_points(origin_mesh)

# Set the threshold values
normal_threshold = 0.01
color_threshold = 0.1


seg_labels = region_growing_segmentation(origin_mesh, adjacency_list, seed_indices, normal_threshold=normal_threshold, color_threshold=color_threshold)

(20615,)
0.9505715771200121
0.9613080139099138
0.9776406519814831
0.9302949526570503
0.951672618239109
0.9247447334924995
0.9085655691140754
0.8628428815226401
0.8701296897432951
0.8739337332087714
0.875378493101553
0.8513158381913729
0.8308613451110658
0.8343160860558514
0.7797895270810773
0.811385572685889
0.8131686598195629
0.8375130833732158
0.8395022378010857
0.8768005255537551
0.920288579302435
0.9138137336655252
0.9249929899551498
0.953156581075064
0.9745988121840218
0.9643666482146703
0.9811512723052888
0.9806904081931604
0.9772641865420252
0.9857410048725974
0.9879824146889351
0.9960618318763799
0.8751259328703055
0.9760032219752799
0.9839323719845816
0.891456835721701
0.796032586271161
0.8068914259902443
0.6474723442319126
0.5785311489980366
0.3885383005884865
0.3574459966333156
0.293870258916762
0.34742571367337544
0.36885913259069203
0.5280720480522245
0.5258381493763616
0.6497581965257742
0.3666374949770049
0.6867706570044032
0.509889900437744
0.7893918885606778
0.74093380

In [79]:
print(np.where(seg_labels == -1)[0].shape)
print("outer:")
print(np.where(seg_labels == 0)[0].shape)
print(np.where(seg_labels == 2)[0].shape)

print("inner:")
print(np.where(seg_labels == 1)[0].shape)
print(np.where(seg_labels == 3)[0].shape)



(3131,)
outer:
(16667,)
(812,)
inner:
(1,)
(4,)


In [80]:
# Visualize the segmentation
def display_region_growth_outcome(mesh, seg_labels):
    # Duplicate the original mesh
    mesh_copy = copy.deepcopy(mesh)
    
    # Update the vertex colors based on labels for visualization
    colors = np.asarray(mesh_copy.vertex_colors)
    for i in range(len(seg_labels)):
        if seg_labels[i] == 0:
            colors[i] = [0, 1, 0]  # Green for outer region
        elif seg_labels[i] == 1:
            colors[i] = [0, 0, 1]  # Blue for inner region
        elif seg_labels[i] == 2:
            colors[i] = [1, 1, 0]  # yellow for outer boundaries
        elif seg_labels[i] == 3:
            colors[i] = [1, 0, 1] # purple for inner boundaries
        # Else keep the original color for label -1 (default color)
        else:
            assert seg_labels[i] == -1, "Invalid label value"

    # Assign the updated colors back to the mesh
    mesh_copy.vertex_colors = o3d.utility.Vector3dVector(colors)
    
    # coordinates axes
    axes = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1, origin=[0, 0, 0])
    # Visualize the segmented mesh with color labels
    o3d.visualization.draw_geometries([mesh_copy, axes])

In [81]:
display_region_growth_outcome(origin_mesh, seg_labels)