In [1]:
import open3d as o3d
import numpy as np
import copy
import cv2
import matplotlib.pyplot as plt
import os

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


### Load the 2D predicted mask images for the 3 views (up, in, out)

In [233]:
def load_pred_img(pred2D_path, mesh_name):
    """Load the three predicted 2D mask images of one mesh, scaled to [0, 1].

    Images are read from ``<pred2D_path>/<mesh_name>_<k>.png`` with
    k = 0 (up), 1 (in) and 2 (out), unchanged (no color conversion), and
    divided by 255.

    Args:
        pred2D_path: directory holding the predicted mask PNGs.
        mesh_name: mesh identifier used as the file-name prefix.

    Returns:
        Tuple ``(pred_img_up, pred_img_in, pred_img_out)`` of float arrays.

    Raises:
        FileNotFoundError: if any of the three images cannot be read.
    """
    imgs = []
    for view_id in range(3):  # 0 = up, 1 = in, 2 = out
        img_path = os.path.join(pred2D_path, f"{mesh_name}_{view_id}.png")
        img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
        # cv2.imread does not raise on a missing/unreadable file; it returns None,
        # which previously surfaced as an opaque TypeError on the division below.
        if img is None:
            raise FileNotFoundError(f"Could not read predicted mask image: {img_path}")
        # convert to color range [0, 1]; assumes 8-bit masks -- TODO confirm for other bit depths
        imgs.append(img / 255)
    pred_img_up, pred_img_in, pred_img_out = imgs

    print(f"pred_img shapes: {pred_img_up.shape}, {pred_img_in.shape}, {pred_img_out.shape}")

    return pred_img_up, pred_img_in, pred_img_out

mesh_name = "001602"

# Data root; set the DPS_ROOT environment variable to run on another machine.
# Raw strings are required: "D:\sunny\..." contains invalid escape sequences
# (\s, \C, \D, \d, \p, ...) that raise SyntaxWarning on Python 3.12+.
DPS_ROOT = os.environ.get("DPS_ROOT", r"D:\sunny\Codes\DPS")
pred2D_dir = os.path.join(DPS_ROOT, "data_png", "pred_0727")
info_dir = os.path.join(DPS_ROOT, "data_png", "info")
label_mesh_dir = os.path.join(DPS_ROOT, "data_teethseg", "label")


pred_img_up, pred_img_in, pred_img_out = load_pred_img(pred2D_dir, mesh_name)
print(np.max(pred_img_up), np.max(pred_img_in), np.max(pred_img_out))

pred_img shapes: (512, 768), (256, 2048), (256, 2048)
1.0 1.0 1.0


### Load reconstruction information:

Each `.npz` file holds 6 arrays:
- "uvpx_up": (num_vert, 2)
- "uvpx_in": (num_vert, 2)
- "uvpx_out": (num_vert, 2)

Each row is the 2D (u, v) pixel coordinate of one mesh vertex in that view's image (u = column, v = row).

- "tri_up": (num_face_up, 3)
- "tri_in": (num_face_in, 3)
- "tri_out": (num_face_out, 3)

Each row holds the 3 vertex indices that form one triangular face assigned to that view.

In [234]:
# Per-view UV pixel coordinates and face splits for this mesh. np.load on an
# .npz keeps the file handle open until closed; the context manager releases it
# once the arrays have been read into memory.
with np.load(os.path.join(info_dir, f"{mesh_name}.npz")) as info:
    uvpx_up = info["uvpx_up"]
    uvpx_in = info["uvpx_in"]
    uvpx_out = info["uvpx_out"]
    tri_up = info["tri_up"]
    tri_in = info["tri_in"]
    tri_out = info["tri_out"]
print(uvpx_up.shape, uvpx_in.shape, uvpx_out.shape)
print(tri_up.shape, tri_in.shape, tri_out.shape)

# load mesh
# NOTE(review): o3d.io.read_triangle_mesh typically only logs a warning and returns
# an empty mesh on a bad path, so check explicitly instead of failing later.
mesh_path = os.path.join(label_mesh_dir, f"{mesh_name}.ply")
mesh = o3d.io.read_triangle_mesh(mesh_path)
if not mesh.has_vertices() or not mesh.has_vertex_colors():
    raise ValueError(f"GT label mesh has no vertices or no vertex colors: {mesh_path}")
vertices = np.asarray(mesh.vertices)
vert_GT_label = 1 - np.asarray(mesh.vertex_colors) # already in range [0, 1], need to be flipped so that 1 => plaque, 0 => tooth
# The tri_* indices index both the uvpx_* arrays and the mesh vertices/colors,
# so all of them must describe the same number of vertices.
assert uvpx_up.shape[0] == uvpx_in.shape[0] == uvpx_out.shape[0] == vertices.shape[0], \
    "UV arrays and mesh vertex count differ"
print(vertices.shape)
print(vert_GT_label.shape)




(20561, 2) (20561, 2) (20561, 2)
(15995, 3) (9219, 3) (14650, 3)
(20561, 3)
(20561, 3)


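### Extract the predicted label and GT label for each triangle face

The prediction is sampled from the predicted image at the face centroid in UV pixel space; the GT is the mean of the face's 3 vertex labels.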

In [235]:
def get_tri_RGB(triangles, vertex_RGB, channel=0, reducer=np.mean):
    """Reduce per-vertex colors/labels to one scalar value per triangle face.

    For every face, the values of its 3 vertices are combined with ``reducer``
    (mean by default) and only channel ``channel`` of the result is kept, so the
    output is one value per face -- not an RGB triple.

    Args:
        triangles: (F, 3) integer array of vertex indices per face.
        vertex_RGB: (N, C) per-vertex color/label array.
        channel: which channel of the reduced color to return (default 0).
        reducer: NumPy reduction accepting ``axis=`` (e.g. ``np.mean``, ``np.min``)
            used to combine the 3 vertex values. TODO: min vs mean still undecided.

    Returns:
        (F,) array with one value per face.
    """
    triangles = np.asarray(triangles)
    if triangles.size == 0:
        return np.array([])
    # vertex_RGB[triangles]: (F, 3, C) -> reduce over the 3 vertices -> (F, C) -> one channel
    return reducer(vertex_RGB[triangles], axis=1)[:, channel]


def get_tri_center_uv(triangles, uv_pixels):
    """Return the centroid of each triangle in UV pixel space.

    ``uv_pixels[triangles]`` gathers the (u, v) of each face's corners; averaging
    over the corner axis yields one (u, v) centroid per face.
    """
    corner_uvs = uv_pixels[triangles]
    centroids = np.mean(corner_uvs, axis=1)
    return centroids

def get_tri_pred_label(tri_uvpx, pred_img_label):
    """Sample the predicted label image at each face's (u, v) pixel position.

    Coordinates are truncated to int and clamped into the image, consistent with
    the per-vertex sampling in ``update_uv_pred_label``. Previously a position on
    the far edge raised IndexError and a negative one silently wrapped around.

    Args:
        tri_uvpx: (F, 2) per-face (u, v) pixel coordinates; u indexes image
            columns, v indexes rows.
        pred_img_label: predicted label image, indexed as ``[row, col]``.

    Returns:
        (F,) array of sampled values ((F, C) for a multi-channel image).
    """
    tri_uvpx = np.asarray(tri_uvpx)
    if tri_uvpx.size == 0:
        return np.array([])
    px_h, px_w = pred_img_label.shape[:2]
    uv_int = tri_uvpx.astype(np.int32)  # truncation toward zero, as before
    u = np.clip(uv_int[:, 0], 0, px_w - 1)
    v = np.clip(uv_int[:, 1], 0, px_h - 1)
    return pred_img_label[v, u]

In [236]:
# Per-view face labels, grouped by view. Prediction: the predicted image sampled
# at each face centroid in UV pixel space. GT: per-vertex GT labels reduced per face.
tri_uvpx_up = get_tri_center_uv(tri_up, uvpx_up)
tri_pred_label_up = get_tri_pred_label(tri_uvpx_up, pred_img_up)
tri_GT_labelGRB_up = get_tri_RGB(tri_up, vert_GT_label)

tri_uvpx_in = get_tri_center_uv(tri_in, uvpx_in)
tri_pred_label_in = get_tri_pred_label(tri_uvpx_in, pred_img_in)
tri_GT_labelGRB_in = get_tri_RGB(tri_in, vert_GT_label)

tri_uvpx_out = get_tri_center_uv(tri_out, uvpx_out)
tri_pred_label_out = get_tri_pred_label(tri_uvpx_out, pred_img_out)
tri_GT_labelGRB_out = get_tri_RGB(tri_out, vert_GT_label)

# Non-zero (positive) face counts per view: prediction first, then GT.
print(*(np.count_nonzero(a) for a in (tri_pred_label_up, tri_pred_label_in, tri_pred_label_out)))
print(*(np.count_nonzero(a) for a in (tri_GT_labelGRB_up, tri_GT_labelGRB_in, tri_GT_labelGRB_out)))

# Total face count per view.
print(*map(len, (tri_pred_label_up, tri_pred_label_in, tri_pred_label_out)))

735 2049 2564
3515 3542 5798
15995 9219 14650


### Evaluation metrics

In [237]:

def compute_metrics_tri(gt_labels, pred_labels, class_id=1,
                        gt_thresh=0.25, pred_thresh=0.0, empty_score=1.0):
    """Compute the IoU and Dice scores for the triangles for one class.

    Class 1 (plaque): GT faces with value > ``gt_thresh`` and predicted faces with
    value > ``pred_thresh``. Class 0 (non-plaque) is the exact complement of the
    class-1 masks, so every face falls in exactly one class on each side. (The
    old ``< 0.25`` / ``== 0`` tests left a GT value equal to 0.25 in neither.)

    Args:
        gt_labels: (F,) per-face GT values.
        pred_labels: (F,) per-face predicted values.
        class_id: 1 for plaque, 0 for non-plaque.
        gt_thresh: GT binarization threshold (default 0.25).
        pred_thresh: prediction threshold (default 0: any non-zero value is
            positive). FIXME: confirm the intended prediction threshold.
        empty_score: value returned for both IoU and Dice when the class is
            absent from both GT and prediction (0/0 case, previously NaN).

    Returns:
        (iou, dice) tuple.

    Raises:
        ValueError: if ``class_id`` is not 0 or 1.
    """
    if class_id not in (0, 1):
        raise ValueError(f"class_id must be 0 or 1, got {class_id!r}")

    # Convert per-face values to binary masks for class 1 (plaque).
    gt_pos = np.asarray(gt_labels) > gt_thresh
    pred_pos = np.asarray(pred_labels) > pred_thresh
    if class_id == 0:
        gt_pos, pred_pos = ~gt_pos, ~pred_pos

    intersection = np.count_nonzero(gt_pos & pred_pos)
    union = np.count_nonzero(gt_pos | pred_pos)
    if union == 0:
        return float(empty_score), float(empty_score)
    iou = intersection / union
    dice = 2 * intersection / (np.count_nonzero(gt_pos) + np.count_nonzero(pred_pos))
    return iou, dice


# Face-level metrics per view, for class 1 (plaque) and class 0 (non-plaque).
iou_up, dice_up = compute_metrics_tri(tri_GT_labelGRB_up, tri_pred_label_up)
iou_in, dice_in = compute_metrics_tri(tri_GT_labelGRB_in, tri_pred_label_in)
iou_out, dice_out = compute_metrics_tri(tri_GT_labelGRB_out, tri_pred_label_out)

iou_up0, dice_up0 = compute_metrics_tri(tri_GT_labelGRB_up, tri_pred_label_up, class_id=0)
iou_in0, dice_in0 = compute_metrics_tri(tri_GT_labelGRB_in, tri_pred_label_in, class_id=0)
iou_out0, dice_out0 = compute_metrics_tri(tri_GT_labelGRB_out, tri_pred_label_out, class_id=0)

# Face-count-weighted mean over the three views. The weights are normalized by
# their own sum, not len(mesh.triangles): the two agree only when the view face
# sets exactly partition the mesh, otherwise the old formula was not a mean.
view_face_counts = [len(tri_up), len(tri_in), len(tri_out)]
if sum(view_face_counts) != len(mesh.triangles):
    print(f"Warning: view face counts sum to {sum(view_face_counts)}, "
          f"but the mesh has {len(mesh.triangles)} faces")

iou_mean = np.average([iou_up, iou_in, iou_out], weights=view_face_counts)
dice_mean = np.average([dice_up, dice_in, dice_out], weights=view_face_counts)
print(f"Class 1 IoU: {iou_mean:.3f}, Class 1 Dice: {dice_mean:.3f}")

iou_mean0 = np.average([iou_up0, iou_in0, iou_out0], weights=view_face_counts)
dice_mean0 = np.average([dice_up0, dice_in0, dice_out0], weights=view_face_counts)
print(f"Class 0 IoU: {iou_mean0:.3f}, Class 0 Dice: {dice_mean0:.3f}")


Class 1 IoU: 0.345, Class 1 Dice: 0.505
Class 0 IoU: 0.792, Class 0 Dice: 0.883


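### Vertex-level label evaluation

Each vertex takes the predicted label of the last view that covers it. Views are applied in the order outward, inward, upward, so the upward view wins where views overlap.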

In [238]:
# Consider each vertex instead of face
def update_uv_pred_label(vert_pred_label, uv_pixel, vert_idx, pred_img_label):
    """Write predicted labels for selected vertices in place and return the array.

    Each vertex in ``vert_idx`` is handled as follows:
    - Its (u, v) pixel position from ``uv_pixel`` is truncated to int.
    - The position is clamped into the image.
    - The pixel ``pred_img_label[v, u]`` is stored in ``vert_pred_label[idx]``,
      overwriting any value already there.

    Args:
        vert_pred_label: per-vertex label array, modified in place.
        uv_pixel: (u, v) pixel coordinates for *all* vertices; u selects the
            image column, v the row.
        vert_idx: indices (into all vertices) of the vertices to update.
        pred_img_label: predicted label image for this view.

    Returns:
        The same ``vert_pred_label`` object.
    """
    img_h, img_w = pred_img_label.shape[:2]
    last_col, last_row = img_w - 1, img_h - 1
    for vid in vert_idx:
        u_f, v_f = uv_pixel[vid]
        col = np.clip(int(u_f), 0, last_col)
        row = np.clip(int(v_f), 0, last_row)
        vert_pred_label[vid] = pred_img_label[row, col]
    return vert_pred_label


# Vertex indices covered by each view (a vertex may belong to several views).
up_vert_idx = np.unique(tri_up.ravel())
in_vert_idx = np.unique(tri_in.ravel())
out_vert_idx = np.unique(tri_out.ravel())

# Every vertex starts at -1 ("not assigned by any view").
vert_pred_label = np.full((vertices.shape[0], 3), -1.0)
# Views are applied outward -> inward -> upward; a later view overwrites
# vertices it shares with an earlier one, so upward takes precedence.
vert_pred_label = update_uv_pred_label(vert_pred_label, uvpx_out, out_vert_idx, pred_img_out)
vert_pred_label = update_uv_pred_label(vert_pred_label, uvpx_in, in_vert_idx, pred_img_in)
vert_pred_label = update_uv_pred_label(vert_pred_label, uvpx_up, up_vert_idx, pred_img_up)

# Binarize: 1.0 if any channel is > 0, else 0.0. Unassigned vertices (-1)
# therefore count as predicted negatives.
# NOTE(review): GT uses "> 0" here while the face-level metric uses "> 0.25" -- confirm intended.
vert_pred_label_binary = np.any(vert_pred_label > 0, axis=1).astype(np.float64)
vert_GT_label_binary = np.any(vert_GT_label > 0, axis=1).astype(np.float64)



# Compare the predicted labels with the ground truth labels for all vertices from the entire mesh
def compute_metrics(gt_labels, pred_labels, empty_score=1.0):
    """Compute the IoU and Dice scores for the vertices.

    Both inputs are interpreted as booleans (non-zero = positive) for every term,
    so the intersection, the union and the Dice denominator are all counts of
    positives. Previously the Dice denominator summed the raw values.

    Args:
        gt_labels: (N,) ground-truth labels (0/1 or bool).
        pred_labels: (N,) predicted labels (0/1 or bool).
        empty_score: value returned for both IoU and Dice when neither array has
            any positive (0/0 case, previously NaN).

    Returns:
        (iou, dice) tuple.
    """
    gt_pos = np.asarray(gt_labels, dtype=bool)
    pred_pos = np.asarray(pred_labels, dtype=bool)
    intersection = np.count_nonzero(gt_pos & pred_pos)
    union = np.count_nonzero(gt_pos | pred_pos)
    if union == 0:
        return float(empty_score), float(empty_score)
    iou = intersection / union
    dice = 2 * intersection / (np.count_nonzero(gt_pos) + np.count_nonzero(pred_pos))
    return iou, dice

# Class 1 (plaque): compare the binary labels directly.
iou, dice = compute_metrics(vert_GT_label_binary, vert_pred_label_binary)

# Class 0 (non-plaque): same comparison on the negated labels.
vert_GT_label_binary0 = vert_GT_label_binary == 0
vert_pred_label_binary0 = vert_pred_label_binary == 0
iou_0, dice_0 = compute_metrics(vert_GT_label_binary0, vert_pred_label_binary0)

print(f"Class 1 IoU: {iou:.3f}, Class 1 Dice: {dice:.3f}")
print(f"Class 0 IoU: {iou_0:.3f}, Class 0 Dice: {dice_0:.3f}")

Class 1 IoU: 0.369, Class 1 Dice: 0.539
Class 0 IoU: 0.795, Class 0 Dice: 0.886


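### Visualization

Vertex colors: green = true positive, black = false positive, red = false negative, grey = true negative.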

In [239]:
# visualize the binary vertex prediction outcome on the mesh
def visualize_pred_labels(GTlabel_mesh, vert_GT_label_binary, vert_pred_label_binary):
    """Return a copy of the mesh with vertices colored by prediction outcome.

    Both label arrays are binary per vertex (1 = plaque); only entries equal to 1
    count as positive. Colors:
    - green: true positive
    - black: false positive (false detection)
    - red: false negative (missed detection)
    - grey (0.8): true negative

    The input mesh is deep-copied and left unmodified.
    """
    mesh_pred = copy.deepcopy(GTlabel_mesh)
    n_vert = len(np.asarray(mesh_pred.vertices))
    pred_pos = np.asarray(vert_pred_label_binary)[:n_vert] == 1
    gt_pos = np.asarray(vert_GT_label_binary)[:n_vert] == 1

    colors = np.full((n_vert, 3), 0.8)          # grey: true negative
    colors[pred_pos & gt_pos] = [0, 1, 0]       # green: true positive
    colors[pred_pos & ~gt_pos] = [0, 0, 0]      # black: false positive
    colors[~pred_pos & gt_pos] = [1, 0, 0]      # red: false negative

    mesh_pred.vertex_colors = o3d.utility.Vector3dVector(colors)
    return mesh_pred

mesh_pred = visualize_pred_labels(
    mesh, vert_GT_label_binary, vert_pred_label_binary
)

# Coordinate frame of size 10 at the origin, shown alongside the mesh for orientation.
axes = o3d.geometry.TriangleMesh.create_coordinate_frame(
    size=10,
    origin=[0, 0, 0],
)
scene = [mesh_pred, axes]
o3d.visualization.draw_geometries(scene)

In [240]:
# post-processing
    # Convert color format of pred mask image
    # (RGB vs BGR, 0-255 vs 0-1, resolution, binary?) 
    # --> pred_mask_img (2D)


##################### 2D-3D Backprojection #####################
# Load info from data_png folder
    # uv_pixel (mapping information)
    # tri (separation information)

# For each case (0-upward, 1-inward, 2-outward)
    # Extract triangle color (prediction) from pred_mask_img
    # --> pred_tri_color (_xxward)



############################ GT ########################
# Load GT label mesh (.ply) file
# For each case (0-upward, 1-inward, 2-outward)
    # Extract triangle color (ground truth) from vertex color
    # --> gt_tri_color (_xxward)





############################ Evaluate ########################
# Compute IoU, Dice for each case

# Visualize the results
