In [1]:
import sys
import os
from trimesh import PointCloud
sys.path.append(os.getcwd())
from glob import glob
# import gen_utils as gu
import numpy as np
import open3d as o3d
from sklearn.neighbors import KDTree
import copy
import argparse
import json
import trimesh
from stl import mesh as stlmesh

# Input/output locations for one scan. Change SAMPLE_NAME / JAW to process a
# different scan; every path below is derived from them.
SAMPLE_NAME = "SAMPLE1"
JAW = "upper"

mesh_path = f"../samples/{SAMPLE_NAME}/{SAMPLE_NAME}_{JAW}.stl"
pred_json_path = f"../results/{SAMPLE_NAME}_{JAW}.json"
# need to be .obj format (presumably so the per-vertex colors are preserved)
colored_mesh_save_path = f"../results/{SAMPLE_NAME}_{JAW}_colored.obj"
individual_tooth_save_dir = f"../results/individual_{JAW}"  # one mesh file per predicted label

# helper functions

def cal_metric(gt_labels, pred_sem_labels, pred_ins_labels, is_half=None, vertices=None):
    """Average per-instance segmentation metrics over the predicted instances.

    Every non-zero value in ``pred_ins_labels`` is one predicted instance. The
    instance is matched to the ground-truth label occurring most often inside
    it, and binary accuracy / F1 / IoU are computed between the instance mask
    and that ground-truth label's mask. The instance counts as semantically
    correct when the most frequent ``pred_sem_labels`` value inside it equals
    the matched ground-truth label (or, if ``is_half`` is truthy, equals that
    label minus 8).

    Args:
        gt_labels: numpy array of ground-truth labels, one per element.
        pred_sem_labels: numpy array of predicted semantic labels, same length.
        pred_ins_labels: numpy array of predicted instance ids; 0 is ignored.
        is_half: if truthy, ``sem_label + 8 == gt_label`` also counts as correct.
        vertices: unused; kept for call compatibility.

    Returns:
        ``(mean_iou, mean_f1, mean_acc, sem_acc, iou_per_instance)``. If there
        are no non-zero instances, the four means are 0.0 and the list is empty.
    """
    ins_label_names = np.unique(pred_ins_labels)
    ins_label_names = ins_label_names[ins_label_names != 0]
    n_instances = len(ins_label_names)
    if n_instances == 0:
        # No predicted instances: the averages below would divide by zero.
        return 0.0, 0.0, 0.0, 0.0, []

    IOU = 0
    F1 = 0
    ACC = 0
    SEM_ACC = 0
    IOU_arr = []
    for ins_label_name in ins_label_names:
        # Instance IoU: match the predicted instance to its majority GT label.
        ins_mask = pred_ins_labels == int(ins_label_name)
        gt_label_uniqs, gt_label_counts = np.unique(gt_labels[ins_mask], return_counts=True)
        gt_label_name = gt_label_uniqs[np.argmax(gt_label_counts)]
        gt_mask = gt_labels == gt_label_name

        TP = np.count_nonzero(gt_mask & ins_mask)
        FN = np.count_nonzero(gt_mask & ~ins_mask)
        FP = np.count_nonzero(~gt_mask & ins_mask)
        TN = np.count_nonzero(~gt_mask & ~ins_mask)

        # gt_label_name occurs inside ins_mask, so TP >= 1 and none of the
        # denominators below can be zero.
        ACC += (TP + TN) / (FP + TP + FN + TN)
        precision = TP / (TP + FP)
        recall = TP / (TP + FN)
        F1 += 2 * (precision * recall) / (precision + recall)
        iou = TP / (FP + TP + FN)
        IOU += iou
        IOU_arr.append(iou)

        # Semantic accuracy: majority predicted semantic label inside the instance.
        pred_sem_label_uniqs, pred_sem_label_counts = np.unique(pred_sem_labels[ins_mask], return_counts=True)
        sem_label_name = pred_sem_label_uniqs[np.argmax(pred_sem_label_counts)]
        if sem_label_name == gt_label_name or (is_half and sem_label_name + 8 == gt_label_name):
            SEM_ACC += 1

    return IOU / n_instances, F1 / n_instances, ACC / n_instances, SEM_ACC / n_instances, IOU_arr

def np_to_pcd(arr, color=[1,0,0]):
    """Build an Open3D point cloud from an (N, >=3) array, painted one color.

    Columns 0-2 become point positions. When the array has at least six
    columns, columns 3-5 become normals. Every point receives ``color``.
    """
    data = np.array(arr)
    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(data[:, :3])

    has_normals = data.shape[1] >= 6
    if has_normals:
        cloud.normals = o3d.utility.Vector3dVector(data[:, 3:6])

    n_points = len(cloud.points)
    cloud.colors = o3d.utility.Vector3dVector([color] * n_points)
    return cloud


def print_3d(*data_3d_ls):
    """Open an Open3D viewer showing every given geometry.

    Plain ``np.ndarray`` arguments are converted with ``np_to_pcd`` (default
    color); anything else is passed to ``draw_geometries`` as-is. Wireframe
    and back-face rendering are enabled.
    """
    geometries = [
        np_to_pcd(item) if type(item) is np.ndarray else item
        for item in data_3d_ls
    ]
    o3d.visualization.draw_geometries(
        geometries,
        mesh_show_wireframe=True,
        mesh_show_back_face=True,
    )

def load_json(file_path, encoding="utf-8"):
    """Parse and return the JSON document stored at ``file_path``.

    The file is decoded as UTF-8 by default rather than with the platform's
    locale encoding, which is not UTF-8 on every system (e.g. Windows).
    """
    with open(file_path, "r", encoding=encoding) as st_json:
        return json.load(st_json)

def read_txt_obj_ls(path, ret_mesh=False, use_tri_mesh=False):
    """Load a triangle mesh and return per-vertex positions stacked with normals.

    Args:
        path: mesh file path.
        ret_mesh: if True, the Open3D ``TriangleMesh`` is appended to the output.
        use_tri_mesh: if True, load with ``o3d.io.read_triangle_mesh`` (the path
            used for .stl input in this notebook). Otherwise ``path`` is parsed
            as a text OBJ file: only ``v`` and ``f`` records are read, and only
            the first three corners of each face are kept. Negative (relative)
            OBJ indices are not supported.

    Returns:
        ``[points]`` or ``[points, mesh]``, where ``points`` is an (N, 6) array
        of ``[x, y, z, nx, ny, nz]`` rows, one per mesh vertex.
    """
    if use_tri_mesh:
        # trimesh can change vertex order in some cases, so Open3D's reader is used.
        mesh = o3d.io.read_triangle_mesh(path)
    else:
        vertex_ls = []
        tri_ls = []
        with open(path, "r") as f:
            for raw_line in f:
                tokens = raw_line.split()
                if not tokens:
                    # Blank line: skip it (stopping here would drop the rest of the file).
                    continue
                if tokens[0] == "v":
                    vertex_ls.append([float(x) for x in tokens[1:4]])
                elif tokens[0] == "f":
                    # Face corners may be "i", "i/t", "i//n" or "i/t/n"; keep the vertex index i.
                    tri_ls.append([int(tok.split("/")[0]) for tok in tokens[1:4]])

        mesh = o3d.geometry.TriangleMesh()
        # reshape keeps the (K, 3) layout even when the file has no vertices/faces.
        mesh.vertices = o3d.utility.Vector3dVector(
            np.asarray(vertex_ls, dtype=np.float64).reshape(-1, 3)
        )
        # OBJ face indices are 1-based; Open3D triangles are 0-based.
        mesh.triangles = o3d.utility.Vector3iVector(
            np.asarray(tri_ls, dtype=np.int32).reshape(-1, 3) - 1
        )

    mesh.compute_vertex_normals()

    norms = np.array(mesh.vertex_normals)
    vertex_ls = np.asarray(mesh.vertices)
    output = [np.concatenate([vertex_ls, norms], axis=1)]

    if ret_mesh:
        output.append(mesh)
    return output

def get_colored_mesh(mesh, label_arr):
    """Paint every vertex of ``mesh`` by its label and return the same mesh.

    Colors are assigned by the label's rank among the sorted distinct labels
    (smallest label -> white, next -> light red, ...), not by the label value,
    so one label can get different colors on scans with different label sets.
    Ranks past the end of the palette wrap around instead of raising.

    Args:
        mesh: Open3D triangle mesh; its ``vertex_colors`` are overwritten in place.
        label_arr: array with one label per mesh vertex.

    Raises:
        ValueError: if the number of labels differs from the number of vertices.
    """
    palette = np.array([
        [255, 255, 255],  # White
        [255, 153, 153],  # Light Red
        [153, 76, 0],     # Brown
        [153, 153, 0],    # Olive
        [76, 153, 0],     # Dark Green
        [0, 153, 153],    # Teal
        [0, 0, 153],      # Navy Blue
        [153, 0, 153],    # Purple
        [153, 0, 76],     # Dark Pink
        [64, 64, 0],      # Olive Drab
        [255, 128, 0],    # Orange
        [255, 0, 0],      # Red
        [0, 255, 0],      # Green
        [0, 0, 255],      # Blue
        [255, 255, 0],    # Yellow
        [255, 0, 255],    # Magenta
        [0, 255, 255],    # Cyan
        [64, 64, 64],     # Gray
    ]) / 255

    labels = np.asarray(label_arr).reshape(-1)
    n_vertices = len(mesh.vertices)
    if labels.shape[0] != n_vertices:
        raise ValueError(
            f"expected one label per vertex ({n_vertices}), got {labels.shape[0]}"
        )

    # rank[i] = position of labels[i] among the sorted unique labels.
    _, rank = np.unique(labels, return_inverse=True)
    label_colors = palette[rank.reshape(-1) % len(palette)]
    mesh.vertex_colors = o3d.utility.Vector3dVector(label_colors)
    return mesh

def get_mesh_of_each_tooth(mesh, label_arr, label):
    """Cut out the sub-mesh made of faces whose three vertices all carry ``label``.

    Faces with any vertex outside the label are dropped. The surviving faces
    are re-indexed onto a compact vertex array (vertices no kept face uses are
    discarded), vertex normals are recomputed, and the shape of the new vertex
    array is printed as a progress trace.
    """
    all_vertices = np.asarray(mesh.vertices)
    all_faces = np.asarray(mesh.triangles)

    selected_ids = np.where(label_arr == label)[0]
    keep_face = np.isin(all_faces, selected_ids).all(axis=1)
    kept_faces = all_faces[keep_face]

    # np.unique yields the sorted old vertex ids still in use; its inverse maps
    # each face corner to that id's position, i.e. its new compact index.
    used_ids, compact_ids = np.unique(kept_faces, return_inverse=True)
    sub_vertices = all_vertices[used_ids]
    sub_faces = compact_ids.reshape(kept_faces.shape)
    print(sub_vertices.shape)

    piece = o3d.geometry.TriangleMesh()
    piece.vertices = o3d.utility.Vector3dVector(sub_vertices)
    piece.triangles = o3d.utility.Vector3iVector(sub_faces)
    piece.compute_vertex_normals()
    return piece


Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
# Load the predicted labels (one per vertex of the merged mesh).
pred_loaded_json = load_json(pred_json_path)
pred_labels = np.array(pred_loaded_json['labels']).reshape(-1)

# Load the mesh and merge duplicated vertices.
# NOTE(review): assumes the labels were predicted on the merged-vertex mesh;
# the assertion below enforces that the counts line up.
_, mesh = read_txt_obj_ls(mesh_path, ret_mesh=True, use_tri_mesh=True)
mesh = mesh.remove_duplicated_vertices()
assert len(pred_labels) == len(mesh.vertices), (
    f"{len(pred_labels)} labels for {len(mesh.vertices)} vertices"
)

# Color the mesh by label (get_colored_mesh paints `mesh` in place, so
# cl_mesh is the same object) and save it. The last expression shows whether
# the write succeeded.
os.makedirs(os.path.dirname(colored_mesh_save_path) or ".", exist_ok=True)
cl_mesh = get_colored_mesh(mesh, pred_labels)
o3d.io.write_triangle_mesh(colored_mesh_save_path, cl_mesh)



True

In [3]:
# Export one mesh per predicted label (tooth_0 is the gum).
os.makedirs(individual_tooth_save_dir, exist_ok=True)
for i in np.unique(pred_labels):
    new_mesh = get_mesh_of_each_tooth(mesh, pred_labels, i)
    out_path = f"{individual_tooth_save_dir}/tooth_{i}.stl"
    # write_triangle_mesh returns a success flag (see the `True` output of the
    # previous cell); don't let a failed write pass silently.
    if not o3d.io.write_triangle_mesh(out_path, new_mesh):
        raise RuntimeError(f"failed to write {out_path}")

(95148, 3)
(9863, 3)
(5310, 3)
(6208, 3)
(7702, 3)
(5874, 3)
(10082, 3)
(8789, 3)
(7179, 3)
(3787, 3)
(5605, 3)
(6539, 3)
(5278, 3)
(9551, 3)
(8608, 3)
