In [2]:
import os
import trimesh
import numpy as np
import json
from collections import defaultdict, Counter
from tqdm import tqdm

In [3]:
# FDI tooth number -> model label (0 = gum).
# FDI quadrants: 1 = upper right, 2 = upper left, 3 = lower left, 4 = lower right.
# Left-side teeth (quadrants 2 and 3) map to labels 1-8 and right-side teeth
# (quadrants 1 and 4) map to 9-16, so upper and lower jaws share the label set.
FDI2label = {
    0: 0,  # gum
    **{
        10 * quadrant + tooth: label_offset + tooth
        for quadrant, label_offset in ((2, 0), (1, 8), (3, 0), (4, 8))
        for tooth in range(1, 9)
    },
}

In [4]:
# RGB colour -> (hex code, tooth name, model label) for every class in the
# ground-truth palette. Tooth names are a quadrant prefix (UL/UR/LL/LR) plus a
# position 1-8; the model label equals the position for UL/LL and position + 8
# for UR/LR, so upper and lower jaws share labels 1-16. Label 0 is the gum.
# Each RGB key is decoded from its hex code so the two can never disagree.
color2label = {
    tuple(int(hex_code[i:i + 2], 16) for i in (0, 2, 4)):
        (hex_code, f"{quadrant}{position}", label_offset + position)
    for quadrant, label_offset, hex_codes in (
        ("UL", 0, ("aaff7f", "aaffff", "ffff00", "ffaa00", "aaaaff", "00aaff", "55aa00", "cccc0f")),
        ("UR", 8, ("ff55ff", "ff557f", "55aa7f", "ff5500", "0055ff", "aa0000", "49f7eb", "7d12f7")),
        ("LL", 0, ("f00000", "fbff03", "2cfbff", "f12fff", "7dff9b", "1a7dff", "ffea9d", "cc7e7e")),
        ("LR", 8, ("ce81d4", "2d8742", "b9cf2d", "4593cf", "cf4868", "04cf04", "2301cf", "52cca9")),
    )
    for position, hex_code in enumerate(hex_codes, start=1)
}
color2label[(125, 125, 125)] = ("7d7d7d", 'GUM', 0)  # gum

In [5]:
# Per-jaw palettes: model label (0 = gum, 1-16 = teeth) -> (hex code, tooth
# name, RGB tuple). Upper and lower jaws reuse the same label numbers, so each
# jaw needs its own table. Labels 1-8 are the left side, 9-16 the right side.
label2color_lower = {
    1: ("f00000", "LL1", (240, 0, 0)),
    2: ("fbff03", "LL2", (251, 255, 3)),
    3: ("2cfbff", "LL3", (44, 251, 255)),
    4: ("f12fff", "LL4", (241, 47, 255)),
    5: ("7dff9b", "LL5", (125, 255, 155)),
    6: ("1a7dff", "LL6", (26, 125, 255)),
    7: ("ffea9d", "LL7", (255, 234, 157)),
    8: ("cc7e7e", "LL8", (204, 126, 126)),

    9: ("ce81d4", "LR1", (206, 129, 212)),
    10: ("2d8742", "LR2", (45, 135, 66)),
    11: ("b9cf2d", "LR3", (185, 207, 45)),
    12: ("4593cf", "LR4", (69, 147, 207)),
    13: ("cf4868", "LR5", (207, 72, 104)),
    14: ("04cf04", "LR6", (4, 207, 4)),
    15: ("2301cf", "LR7", (35, 1, 207)),
    16: ("52cca9", "LR8", (82, 204, 169)),

    # gum
    0: ("7d7d7d", 'GUM', (125, 125, 125)),
}

label2color_upper = {
    1: ("aaff7f", "UL1", (170, 255, 127)),
    2: ("aaffff", "UL2", (170, 255, 255)),
    3: ("ffff00", "UL3", (255, 255, 0)),
    4: ("ffaa00", "UL4", (255, 170, 0)),
    5: ("aaaaff", "UL5", (170, 170, 255)),
    6: ("00aaff", "UL6", (0, 170, 255)),
    7: ("55aa00", "UL7", (85, 170, 0)),
    8: ("cccc0f", "UL8", (204, 204, 15)),

    9: ("ff55ff", "UR1", (255, 85, 255)),
    10: ("ff557f", "UR2", (255, 85, 127)),
    11: ("55aa7f", "UR3", (85, 170, 127)),
    12: ("ff5500", "UR4", (255, 85, 0)),
    13: ("0055ff", "UR5", (0, 85, 255)),
    14: ("aa0000", "UR6", (170, 0, 0)),
    15: ("49f7eb", "UR7", (73, 247, 235)),
    16: ("7d12f7", "UR8", (125, 18, 247)),

    # gum
    0: ("7d7d7d", 'GUM', (125, 125, 125)),
}

In [6]:
def face_labels_to_vertex_labels(faces, face_labels, num_vertices, unlabeled=-1):
    """Convert per-face labels to per-vertex labels by majority vote.

    Each vertex takes the most frequent label among the faces that reference
    it. Ties go to the label seen first when the incident faces are visited in
    face order (``Counter.most_common`` orders equal counts by first encounter).

    Args:
        faces: (F, 3) array-like of vertex indices per face.
        face_labels: (F,) array-like of labels; ``face_labels[i]`` belongs to
            ``faces[i]``.
        num_vertices: Total number of vertices; the result has this length.
            Vertex indices outside ``[0, num_vertices)`` are ignored.
        unlabeled: Value stored for vertices not referenced by any face
            (default -1).

    Returns:
        (num_vertices,) ndarray of vertex labels. Its dtype is that of
        ``face_labels``, widened (e.g. uint8 -> int16) only when that dtype
        cannot represent ``unlabeled``.
    """
    faces = np.asarray(faces)
    face_labels = np.asarray(face_labels)

    # Labels of every face incident to each vertex, collected in face order.
    vertex_face_labels = defaultdict(list)
    for face_idx, face in enumerate(faces):
        label = face_labels[face_idx]
        for v in face:
            vertex_face_labels[v].append(label)

    # An unsigned dtype cannot hold the default -1 sentinel (numpy either wraps
    # it or raises), so widen the dtype just enough to fit the sentinel.
    sentinel_dtype = np.min_scalar_type(unlabeled)
    dtype = face_labels.dtype
    if not np.can_cast(sentinel_dtype, dtype):
        dtype = np.result_type(dtype, sentinel_dtype)

    vertex_labels = np.full(num_vertices, unlabeled, dtype=dtype)
    for v in range(num_vertices):
        incident = vertex_face_labels.get(v)
        if incident:
            # Mode of the adjacent face labels.
            vertex_labels[v] = Counter(incident).most_common(1)[0][0]

    return vertex_labels

In [7]:
def _ply_header(num_vertices, num_faces):
    """Return the ASCII PLY header for RGBA-coloured vertices and faces."""
    return (f"ply\n"
            f"format ascii 1.0\n"
            f"comment VCGLIB generated\n"
            f"element vertex {num_vertices}\n"
            f"property double x\n"
            f"property double y\n"
            f"property double z\n"
            f"property uchar red\n"
            f"property uchar green\n"
            f"property uchar blue\n"
            f"property uchar alpha\n"
            f"element face {num_faces}\n"
            f"property list uchar int vertex_indices\n"
            f"property uchar red\n"
            f"property uchar green\n"
            f"property uchar blue\n"
            f"property uchar alpha\n"
            f"end_header\n")


def output_pred_ply(pred_mask, cell_coords, path, point_coords=None, face_info=None, vertex_colors=None):
    """Write colours to an ASCII PLY file, as a coloured mesh or as a point cloud.

    Mesh mode (``point_coords`` and ``face_info`` both given): every row of
    ``point_coords`` becomes a vertex coloured by ``vertex_colors`` (grey
    125/125/125 when None), and each face of ``face_info`` is written with the
    colour at the same index of ``pred_mask``. Faces whose three indices are all
    0 are skipped. ``cell_coords`` is not used in this mode.

    Point-cloud mode (otherwise): each ``(pred_mask[i], cell_coords[i])`` pair
    becomes one coloured vertex and no faces are written.

    Args:
        pred_mask: Per-face (mesh mode) or per-point (point-cloud mode) RGB colours.
        cell_coords: Point coordinates for point-cloud mode.
        path: Output file path (overwritten).
        point_coords: Vertex coordinates for mesh mode.
        face_info: Per-face triples of vertex indices for mesh mode.
        vertex_colors: Optional per-vertex RGB colours for mesh mode.

    Colour channels are written as integers and alpha is always 255, matching
    the ``uchar`` properties declared in the header.
    """
    if point_coords is not None and face_info is not None:
        vertex_lines = []
        for idx, pc in enumerate(point_coords):
            if vertex_colors is None:
                r, g, b = 125, 125, 125
            else:
                r, g, b = (int(vertex_colors[idx][k]) for k in range(3))
            vertex_lines.append(f'{pc[0]} {pc[1]} {pc[2]} {r} {g} {b} 255\n')

        face_lines = []
        for color, fi in zip(pred_mask, face_info):
            # An all-zero index triple is skipped rather than written as a face.
            if fi[0] == 0 and fi[1] == 0 and fi[2] == 0:
                continue
            face_lines.append(f'3 {int(fi[0])} {int(fi[1])} {int(fi[2])} '
                              f'{int(color[0])} {int(color[1])} {int(color[2])} 255\n')
    else:
        vertex_lines = [f'{coord[0]} {coord[1]} {coord[2]} '
                        f'{int(color[0])} {int(color[1])} {int(color[2])} 255\n'
                        for color, coord in zip(pred_mask, cell_coords)]
        face_lines = []

    # Counts come from the lines actually written, so the header always agrees
    # with the body (zip() stops at the shorter of its inputs).
    header = _ply_header(len(vertex_lines), len(face_lines))

    with open(path, 'w', encoding='ascii') as f:
        f.write(header)
        f.writelines(vertex_lines)
        f.writelines(face_lines)

Change the paths below to point to your own dataset and output locations.

In [6]:
# Paths for a single example -- change these to your own.
# NOTE: the batch loop further down reassigns `file` and `save_path` for every
# example, so these values only matter when this cell is used on its own.
file, save_path = (
    '.datasets/teeth3ds/sample/upper/YBSESUN6/YBSESUN6_upper.obj',  # data path
    'tmp/YBSESUN6_upper_gt.ply',  # save path
)

In [8]:
# Example scans to render, each named "<scan id>_<upper|lower>".
file_ls = """
    016FSM14_lower NQJYZS60_upper CXAJM3O9_lower Z83V9A9D_lower ZD89X7G1_lower
    019NUXJV_lower 01KEK90A_lower MNIAB8K3_lower IUIE4BYI_lower 4W9X0QQI_upper
    C3TQ47Z0_upper 01CDZ2WA_upper B5GFZIRW_lower 4J24X0ES_upper 01K3G866_lower
    0199XT22_upper NX1SXEJY_lower 6VX3OJFR_lower 01MCNDR0_upper DC8VMT30_lower
    MXWIBTGF_lower 3EU06ZN9_upper 019M7KEN_upper JRG6Y6E0_upper S0AON6PZ_lower
""".split()

In [9]:
# Render the ground truth of every example in `file_ls` to a coloured PLY:
# face colours come from per-face labels, vertex colours from the majority
# label of each vertex's adjacent faces.
for file_name in tqdm(file_ls, desc="Processing examples"):
    file_id, file_view = file_name.split('_')[:2]  # "<scan id>_<upper|lower>"
    file = f'.datasets/teeth3ds/{file_view}/{file_id}/{file_name}.obj'  # data path
    save_path = f'tmp/typical_examples/{file_name}_gt.ply'  # save path
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    # Choose the jaw's palette once per scan; both jaws use labels 0-16 but
    # different colours.
    if file_view == 'upper':
        label2color = label2color_upper
    elif file_view == 'lower':
        label2color = label2color_lower
    else:
        raise ValueError(f"expected '<id>_upper' or '<id>_lower', got {file_name!r}")
    # Row i of the lookup table is the RGB colour of label i.
    color_lut = np.array([label2color[i][2] for i in range(max(label2color) + 1)],
                         dtype=np.uint8)

    # The JSON labels are indexed by vertex in the file's order, so load with
    # process=False to stop trimesh from merging/reordering vertices.
    mesh = trimesh.load(file, process=False)

    with open(os.path.splitext(file)[0] + '.json') as f:
        data = json.load(f)
    vertex_fdi = np.array(data["labels"])  # one FDI code per vertex
    # Each face takes the FDI code of its first vertex, mapped to a model label.
    labels = np.array([FDI2label[fdi] for fdi in vertex_fdi[mesh.faces][:, 0]],
                      dtype=np.int64)

    mask = color_lut[labels]  # shape: (n_faces, 3), uint8

    vertex_labels = face_labels_to_vertex_labels(mesh.faces, labels, len(mesh.vertices))
    # Vertices referenced by no face come back as -1; paint them with the
    # label-0 (gum) colour instead of failing the colour lookup.
    vertex_mask = color_lut[np.where(vertex_labels < 0, 0, vertex_labels)]  # shape: (n_vertices, 3)

    point_coords = mesh.vertices                # shape: (n_vertices, 3)
    face_info = mesh.faces                      # shape: (n_faces, 3)
    output_pred_ply(mask, None, save_path, point_coords, face_info, vertex_mask)



Processing examples: 100%|██████████| 25/25 [00:55<00:00,  2.21s/it]
