In [1]:
import argparse
import random

import numpy as np
import open3d as o3d
import torch
import trimesh
from scipy.spatial import cKDTree as KDTree
import cv2

def normalize(x):
    """Return `x` scaled to unit Euclidean length (no guard for a zero-length input)."""
    length = np.linalg.norm(x)
    return x / length


def viewmatrix(z, up, pos):
    """Build a 3x4 camera matrix whose columns are [right, up, forward, pos].

    `z` gives the forward direction and `up` an approximate up vector; the
    returned up column is re-orthogonalized against forward and right.
    """
    forward = normalize(z)
    right = normalize(np.cross(up, forward))
    true_up = normalize(np.cross(forward, right))
    columns = [right, true_up, forward, pos]
    return np.stack(columns, 1)


def completion_ratio(gt_points, rec_points, dist_th=0.05):
    """Fraction of `gt_points` whose nearest neighbour in `rec_points` is closer than `dist_th`.

    Swapping the two arguments turns this into a precision-style ratio.
    """
    rec_tree = KDTree(rec_points)
    nearest_dist, _ = rec_tree.query(gt_points)
    is_close = nearest_dist < dist_th
    return np.mean(is_close.astype(np.float64))


def accuracy(gt_points, rec_points):
    """Mean distance from each reconstructed point to its nearest ground-truth point.

    Returns:
        (mean distance, per-point distance array aligned with `rec_points`).
    """
    per_point = KDTree(gt_points).query(rec_points)[0]
    return np.mean(per_point), per_point


def completion(gt_points, rec_points):
    """Mean distance from each ground-truth point to its nearest reconstructed point.

    Returns:
        (mean distance, per-point distance array aligned with `gt_points`).
    """
    # The tree is built over the reconstruction and queried with the GT points.
    rec_points_kd_tree = KDTree(rec_points)
    per_point = rec_points_kd_tree.query(gt_points)[0]
    return np.mean(per_point), per_point

def write_vis_pcd(file, points, colors):
    """Save `points` with per-point `colors` to `file` as an Open3D point cloud."""
    to_vec3 = o3d.utility.Vector3dVector
    cloud = o3d.geometry.PointCloud()
    cloud.points = to_vec3(points)
    cloud.colors = to_vec3(colors)
    o3d.io.write_point_cloud(file, cloud)

def get_align_transformation(rec_meshfile, gt_meshfile, threshold=0.1, trans_init=None):
    """
    Get the transformation matrix to align the reconstructed mesh to the ground truth mesh.

    Runs point-to-point ICP between the *vertex sets* of the two meshes (faces are
    ignored), with the reconstruction as the ICP source and the GT as the target.

    Args:
        rec_meshfile: Path of the reconstructed mesh (ICP source).
        gt_meshfile: Path of the ground-truth mesh (ICP target).
        threshold: Maximum correspondence distance passed to ICP, in the meshes'
            coordinate units. Defaults to 0.1, the value previously hard-coded.
        trans_init: Optional 4x4 initial transform; the identity is used when None.

    Returns:
        The 4x4 transformation estimated by ICP.

    Raises:
        ValueError: if either file yields a mesh with no vertices (e.g. a bad path),
            instead of silently running ICP on an empty cloud.
    """
    o3d_rec_mesh = o3d.io.read_triangle_mesh(rec_meshfile)
    o3d_gt_mesh = o3d.io.read_triangle_mesh(gt_meshfile)
    for path, mesh in ((rec_meshfile, o3d_rec_mesh), (gt_meshfile, o3d_gt_mesh)):
        if not mesh.has_vertices():
            raise ValueError(f"no vertices loaded from {path!r}")
    o3d_rec_pc = o3d.geometry.PointCloud(points=o3d_rec_mesh.vertices)
    o3d_gt_pc = o3d.geometry.PointCloud(points=o3d_gt_mesh.vertices)
    if trans_init is None:
        trans_init = np.eye(4)
    reg_p2p = o3d.pipelines.registration.registration_icp(
        o3d_rec_pc, o3d_gt_pc, threshold, trans_init,
        o3d.pipelines.registration.TransformationEstimationPointToPoint())
    return reg_p2p.transformation


Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [31]:
# Load the prediction and the ground truth without trimesh's automatic processing,
# so vertex/face arrays stay exactly as stored in the files.
eval_dir = '/home/wongyeom/workspace/objsharp/replica_eval/munhyen'
pred_path = f'{eval_dir}/model.obj'
gt_path = f'{eval_dir}/1.obj'

mesh_rec, mesh_gt = (trimesh.load(p, process=False) for p in (pred_path, gt_path))

# NOTE(review): in the saved run both meshes print identical shapes and the name `1.obj`,
# and the final F-score is ~1.0 — confirm model.obj is not a copy of the GT mesh.
print(mesh_rec, mesh_gt, sep='\n')


<trimesh.Trimesh(vertices.shape=(29842, 3), faces.shape=(59508, 3), name=`1.obj`)>
<trimesh.Trimesh(vertices.shape=(29842, 3), faces.shape=(59508, 3), name=`1.obj`)>


In [32]:
# Move both meshes into the frame of the GT's oriented bounding box, applying the same
# transform to each so they stay registered to one another.
# Not idempotent: re-running this cell transforms the already-aligned meshes again.
to_align, _ = trimesh.bounds.oriented_bounds(mesh_gt)
rotation, offset = to_align[:3, :3], to_align[:3, 3:]
for _mesh in (mesh_gt, mesh_rec):
    _mesh.vertices = (rotation @ _mesh.vertices.T + offset).T

In [33]:
# Axis-aligned crop box: the GT bounds enlarged by 0.5% about their center.
# The previous `min * 1.005` / `max * 1.005` only enlarged the box when it straddled
# the origin (min < 0 < max); for a box on one side of the origin it shifted or shrank
# it instead. For an origin-centered box (as after the OBB alignment) both forms agree.
BBOX_SCALE = 1.005
gt_lo = mesh_gt.vertices.min(axis=0)
gt_hi = mesh_gt.vertices.max(axis=0)
gt_center = (gt_lo + gt_hi) / 2
gt_half_extent = (gt_hi - gt_lo) / 2
min_points = gt_center - gt_half_extent * BBOX_SCALE
max_points = gt_center + gt_half_extent * BBOX_SCALE

In [34]:
# Per-axis, strict "inside the box" tests for every reconstructed vertex.
mask_min = np.greater(mesh_rec.vertices, min_points[None])
mask_max = np.less(mesh_rec.vertices, max_points[None])

In [35]:
# Inspect the per-axis lower/upper-bound masks (one after the other).
print(mask_min, mask_max, sep='\n')

[[ True  True  True]
 [ True  True  True]
 [ True  True  True]
 ...
 [ True  True  True]
 [ True  True  True]
 [ True  True  True]]
[[ True  True  True]
 [ True  True  True]
 [ True  True  True]
 ...
 [ True  True  True]
 [ True  True  True]
 [ True  True  True]]


In [36]:
# Keep a vertex only if it passes both bounds on every axis.
mask = np.logical_and(mask_min, mask_max).all(axis=1)
# Keep a face only if every one of its vertices is kept.
face_mask = np.all(mask[mesh_rec.faces], axis=1)

In [37]:
# Crop the reconstruction to the padded GT box, in place.
# Order matters: face_mask was computed in the previous cell against the ORIGINAL
# face ordering. update_vertices(mask) presumably re-indexes mesh_rec.faces for the
# surviving vertices (trimesh behavior — confirm for your version), and update_faces
# then drops every face that referenced a removed vertex.
# NOTE(review): mutates mesh_rec; mask/face_mask only match the pre-crop mesh, so
# re-running this cell without re-running the earlier cells may fail or mis-crop.
mesh_rec.update_vertices(mask)
mesh_rec.update_faces(face_mask)

print(mesh_rec)

<trimesh.Trimesh(vertices.shape=(29842, 3), faces.shape=(59508, 3), name=`1.obj`)>


In [38]:
# Sample points on both surfaces; the metrics below are computed on these samples.
NUM_SURFACE_SAMPLES = 200000
SAMPLE_SEED = 0
# NOTE(review): when no `seed` argument is given, trimesh's sample_surface draws from
# NumPy's global RNG (in the trimesh versions checked — confirm for yours), so seeding
# it here makes the sampled points, and the F-score below, reproducible across runs.
np.random.seed(SAMPLE_SEED)

# sample_surface returns (points, face_index); only the points are used.
rec_pc = trimesh.sample.sample_surface(mesh_rec, NUM_SURFACE_SAMPLES)
rec_pc_tri = trimesh.PointCloud(vertices=rec_pc[0])

gt_pc = trimesh.sample.sample_surface(mesh_gt, NUM_SURFACE_SAMPLES)
gt_pc_tri = trimesh.PointCloud(vertices=gt_pc[0])

print(rec_pc_tri, gt_pc_tri)

<trimesh.PointCloud(vertices.shape=(200000, 3))> <trimesh.PointCloud(vertices.shape=(200000, 3))>


In [39]:
# Recall: fraction of GT samples within completion_ratio's default dist_th (0.05,
# in mesh units) of the reconstruction.
completion_ratio_rec = completion_ratio(gt_pc_tri.vertices, rec_pc_tri.vertices)
# Precision: the same ratio with the roles swapped.
precision_ratio_rec = completion_ratio(rec_pc_tri.vertices, gt_pc_tri.vertices)

# Harmonic mean of precision and recall; defined as 0 when both are 0 instead of
# dividing by zero.
ratio_sum = completion_ratio_rec + precision_ratio_rec
fscore = 2 * precision_ratio_rec * completion_ratio_rec / ratio_sum if ratio_sum > 0 else 0.0

print(fscore)

0.99999749999375
