In [1]:
import torch
import numpy as np
import open3d as o3d
from tqdm import tqdm
import cv2
from PIL import Image

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


### Read RGB-D frames and camera parameters


In [104]:
def read_rgbd_and_pose(data_root):
    """
    Load per-frame camera poses, intrinsics, RGB images and depth maps from `data_root`.

    Placeholder: the loading logic is still a TODO, so four empty lists are returned.
    Per-frame shapes the rest of the notebook expects (see the printed output of the
    next cell): pose (4, 4), intrinsic (4, 4), rgb (H, W, 3) with values 0-255,
    depth (H, W) in meters.

    Returns a tuple (poses, intrinsics, rgbs, depths) of four separate lists.
    """
    # TODO: read rgbd and camera parameters from data_root
    poses = []
    intrinsics = []
    rgbs = []
    depths = []
    return poses, intrinsics, rgbs, depths

In [114]:
ROOT_DIR = 'YOUR PROJECT DIR'  # TODO: point this at the dataset root
poses, intrinsics, rgbs, depths = read_rgbd_and_pose(ROOT_DIR)
# Fail early with a readable message instead of an IndexError on poses[0]
assert len(poses) > 0, "read_rgbd_and_pose returned no frames; implement it and check ROOT_DIR"
# TSDF fusion indexes all four lists with the same frame index
assert len(poses) == len(intrinsics) == len(rgbs) == len(depths), "per-frame lists differ in length"
print(poses[0].shape, intrinsics[0].shape, rgbs[0].shape, depths[0].shape)
print(len(poses), len(intrinsics), len(rgbs), len(depths))
H, W = rgbs[0].shape[:2]

(4, 4) (4, 4) (800, 800, 3) (800, 800)
100 100 100 100


### Extract a mesh from the RGB-D frames (TSDF fusion)

In [107]:
def post_process_mesh(mesh, cluster_to_keep=1, min_triangles=50):
    """
    Post-process a mesh to filter out floaters and disconnected parts.

    Connected triangle clusters are ranked by triangle count. A cluster is removed when it
    has fewer triangles than the `cluster_to_keep`-th largest cluster, and clusters with
    fewer than `min_triangles` triangles are always removed. Clusters tied with the
    threshold are kept, so more than `cluster_to_keep` clusters may survive.

    :param mesh: an Open3D TriangleMesh. It is deep-copied; the input is not modified.
    :param cluster_to_keep: number of largest clusters to keep. Values larger than the
        number of clusters keep every cluster that passes `min_triangles`.
    :param min_triangles: clusters smaller than this are always dropped (default 50).
    :return: the filtered copy of the mesh (an unchanged copy for an empty mesh).
    """
    import copy

    print("post processing")
    mesh_0 = copy.deepcopy(mesh)
    triangle_clusters, cluster_n_triangles, _cluster_area = mesh_0.cluster_connected_triangles()

    triangle_clusters = np.asarray(triangle_clusters)
    cluster_n_triangles = np.asarray(cluster_n_triangles)
    if cluster_n_triangles.size == 0:
        # Empty mesh: there is nothing to rank, and indexing the sorted counts would raise
        return mesh_0

    sorted_counts = np.sort(cluster_n_triangles)
    # Clamp the rank so asking for more clusters than exist keeps them all instead of raising
    n_cluster = sorted_counts[-min(cluster_to_keep, sorted_counts.size)]
    n_cluster = max(n_cluster, min_triangles)

    # Per-triangle mask: look up the size of the cluster each triangle belongs to
    triangles_to_remove = cluster_n_triangles[triangle_clusters] < n_cluster
    mesh_0.remove_triangles_by_mask(triangles_to_remove)
    mesh_0.remove_unreferenced_vertices()
    mesh_0.remove_degenerate_triangles()
    return mesh_0


class MeshExtractor(object):
    """
    Fuse posed RGB-D frames into a triangle mesh with Open3D's ScalableTSDFVolume.

    poses:       per-frame 4x4 matrices. Each is inverted to obtain the extrinsic passed to
                 Open3D, so they are treated as camera-to-world transforms.
    intrinsics:  per-frame matrices; only [0,0]=fx, [1,1]=fy, [0,2]=cx and [1,2]=cy are read.
    rgbs:        per-frame color images, values in 0-255, shape [H, W, 3].
    depths:      per-frame depth maps in meters, shape [H, W] (or [H, W, 1]).
    depth_trunc: depth values beyond this are ignored during fusion.
    sdf_trunc:   TSDF truncation distance.
    voxel_size:  TSDF voxel edge length.
    """

    def __init__(self, poses, intrinsics, rgbs, depths, depth_trunc=3.0, sdf_trunc=0.02, voxel_size=0.004):
        self.poses = poses
        self.intrinsics = intrinsics
        self.rgbs = rgbs
        self.depths = depths
        self.voxel_size = voxel_size
        self.sdf_trunc = sdf_trunc
        self.depth_trunc = depth_trunc
        print('voxel_size: ', self.voxel_size)
        print('sdf_trunc: ', self.sdf_trunc)
        print('depth_trunc: ', self.depth_trunc)

    def extract_mesh(self, post_process=True):
        """Fuse all frames with the configured parameters; optionally drop small disconnected clusters."""
        mesh = self.extract_mesh_bounded(voxel_size=self.voxel_size, sdf_trunc=self.sdf_trunc, depth_trunc=self.depth_trunc)
        if post_process:
            mesh = post_process_mesh(mesh)
        return mesh

    @staticmethod
    def _frame_size(rgb, depth):
        """Return (height, width) of one frame; raise ValueError if color and depth sizes differ."""
        rgb_hw = tuple(int(s) for s in np.shape(rgb)[:2])
        depth_hw = tuple(int(s) for s in np.shape(depth)[:2])
        if rgb_hw != depth_hw:
            raise ValueError(
                "color image is {}x{} but depth map is {}x{}".format(*rgb_hw, *depth_hw)
            )
        return depth_hw

    def extract_mesh_bounded(self, voxel_size, sdf_trunc, depth_trunc):
        """
        Perform TSDF fusion given a fixed depth range.

        voxel_size: the voxel size of the volume
        sdf_trunc: truncation value
        depth_trunc: maximum depth range; should depend on the scene's scale
        return: o3d.geometry.TriangleMesh extracted from the fused volume
        """
        volume = o3d.pipelines.integration.ScalableTSDFVolume(
            voxel_length=voxel_size,
            sdf_trunc=sdf_trunc,
            color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8,
        )

        for i, pose in tqdm(enumerate(self.poses), desc="TSDF integration"):
            intrinsic, rgb, depth = self.intrinsics[i], self.rgbs[i], self.depths[i]
            # Open3D Image buffers must be C-contiguous and uint8/uint16/float32, so force
            # the dtypes here (a float64 or sliced depth array would be rejected)
            color_np = np.asarray(rgb, order="C", dtype=np.uint8)
            depth_np = np.asarray(depth, order="C", dtype=np.float32)
            # Take the image size from the frame itself, not from notebook globals H/W
            height, width = self._frame_size(color_np, depth_np)

            rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(
                o3d.geometry.Image(color_np),
                o3d.geometry.Image(depth_np),
                depth_trunc=depth_trunc,
                convert_rgb_to_intensity=False,
                depth_scale=1.0,  # depth is already in meters
            )
            camera = o3d.camera.PinholeCameraIntrinsic(
                width=width,
                height=height,
                cx=intrinsic[0, 2].item(),
                cy=intrinsic[1, 2].item(),
                fx=intrinsic[0, 0].item(),
                fy=intrinsic[1, 1].item(),
            )
            # integrate() takes a world-to-camera extrinsic, hence the inverse of the pose
            extrinsic = np.linalg.inv(pose)
            volume.integrate(rgbd, intrinsic=camera, extrinsic=extrinsic)

        return volume.extract_triangle_mesh()

In [108]:
mesh_extractor = MeshExtractor(poses, intrinsics, rgbs, depths)
mesh = mesh_extractor.extract_mesh(post_process=False)
# An empty fusion result usually means wrong depth units or too small a depth_trunc
assert len(mesh.triangles) > 0, "TSDF fusion produced an empty mesh"
recon_path = "mesh.ply"
# write_triangle_mesh signals failure through its return value rather than raising
write_ok = o3d.io.write_triangle_mesh(recon_path, mesh)
assert write_ok, "failed to write {}".format(recon_path)
write_ok

voxel_size:  0.004
sdf_trunc:  0.02
depth_trunc:  3.0


TSDF integration: 100it [00:01, 77.62it/s]


True

In [109]:
def visualize_point_cloud(data):
    """
    Visualizes a point cloud using Open3D. Supports N*3 and N*6 point clouds and accepts
    NumPy arrays, PyTorch tensors, or nested sequences convertible with np.asarray.

    :param data: array-like of shape (N, 3) or (N, 6).
                 For (N, 3), it represents the (x, y, z) coordinates of the points.
                 For (N, 6), it represents the (x, y, z, r, g, b) coordinates and colors of the points;
                 colors are passed to Open3D unchanged (Open3D expects float colors in [0, 1]).
    :raises ValueError: if data is not 2-D with 3 or 6 columns.
    """
    if isinstance(data, torch.Tensor):
        # detach() first: .numpy() raises on tensors that require grad
        data = data.detach().cpu().numpy()
    data = np.asarray(data)

    # Check ndim as well, so 1-D input gets this ValueError rather than an IndexError on shape[1]
    if data.ndim != 2 or data.shape[1] not in (3, 6):
        raise ValueError("The input data must have shape (N, 3) or (N, 6).")

    point_cloud = o3d.geometry.PointCloud()
    point_cloud.points = o3d.utility.Vector3dVector(data[:, :3])

    if data.shape[1] == 6:
        point_cloud.colors = o3d.utility.Vector3dVector(data[:, 3:])

    o3d.visualization.draw_geometries([point_cloud])

### Visualize the target and reconstructed meshes as point clouds

In [133]:
recon_tgt_path = "recon_tgt.ply"
pcd = o3d.io.read_point_cloud(recon_tgt_path)
xyz, color = np.array(pcd.points), np.array(pcd.colors)
# A missing or unreadable file yields an empty cloud rather than an exception
assert len(xyz) > 0, "no points read from {}".format(recon_tgt_path)
# Files without vertex colors give an empty color array; show geometry only in that case
visualize_point_cloud(np.concatenate((xyz, color), axis=1) if len(color) == len(xyz) else xyz)

In [None]:
pcd = o3d.io.read_point_cloud(recon_path)
xyz, color = np.array(pcd.points), np.array(pcd.colors)
# A missing or unreadable file yields an empty cloud rather than an exception
assert len(xyz) > 0, "no points read from {}".format(recon_path)
# Files without vertex colors give an empty color array; show geometry only in that case
visualize_point_cloud(np.concatenate((xyz, color), axis=1) if len(color) == len(xyz) else xyz)

### Compute the reconstruction error with the Chamfer distance

In [None]:
from pytorch3d.loss import chamfer_distance
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.io import load_ply
from pytorch3d.structures import Meshes

def compute_chamfer(recon_pts, gt_pts, device=None):
    """
    Symmetric Chamfer distance between two point sets, computed with pytorch3d.

    :param recon_pts: point tensor accepted by pytorch3d's chamfer_distance,
        e.g. (1, P, 3) as returned by sample_points_from_meshes for a single mesh.
    :param gt_pts: point tensor in the same layout as recon_pts.
    :param device: torch device to compute on; defaults to CUDA when available, else CPU.
    :return: the distance as a Python float. batch_reduction=None yields one value per
        batch element, so .item() requires a batch size of 1.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    with torch.no_grad():
        dist, _ = chamfer_distance(
            recon_pts.to(device),
            gt_pts.to(device),
            batch_reduction=None,
            single_directional=False,
        )
    return dist.item()


def compute_recon_error(recon_path, gt_path, n_samples=10000):
    """
    Compare a reconstructed mesh with a ground-truth mesh via sampled-point Chamfer distance.

    Both .ply files are loaded with pytorch3d, and `n_samples` points are randomly sampled
    from each surface. Seed torch beforehand for a reproducible value.

    :return: compute_chamfer(recon_pts, gt_pts) * 1000.

    NOTE(review): with its default norm=2, pytorch3d's chamfer_distance averages *squared*
    nearest-neighbour distances. If the meshes are in meters, the x1000 scaling therefore
    does not yield millimetres. The scaling is kept unchanged so reported numbers stay
    comparable; confirm the intended unit.
    """
    def _load_mesh(path):
        # Wrap a single .ply mesh in a batch-of-one Meshes structure
        verts, faces = load_ply(path)
        return Meshes(verts=[verts], faces=[faces])

    recon_mesh = _load_mesh(recon_path)
    gt_mesh = _load_mesh(gt_path)

    # Sample gt before recon so seeded runs consume the RNG in the same order as before
    gt_pts = sample_points_from_meshes(gt_mesh, num_samples=n_samples)
    recon_pts = sample_points_from_meshes(recon_mesh, num_samples=n_samples)
    return compute_chamfer(recon_pts, gt_pts) * 1000  # scaled x1000; see NOTE on units above

In [None]:
gt_path = "gt.ply"
# Fix the point-sampling RNG so the reported Chamfer value is reproducible
torch.manual_seed(0)
error = compute_recon_error(recon_path, gt_path)
message = 'Chamfer Distance: {:.4f} mm'.format(error)
print(message)