In [None]:
%cd /ibex/user/slimhy/PADS/code
from datasets.metadata import hex_to_class, COMPAT_TRANSFORMS
from datasets.sampling import normalize_pc, sample_surface
from util.mesh import CUDAMesh
from tqdm.notebook import tqdm

import trimesh
import matplotlib.cm as cm
import numpy as np
import pickle
import os



def generate_colormap_colors(n, colormap_name="viridis", alpha=1.0):
    """
    Sample ``n`` evenly spaced RGBA colors from a Matplotlib colormap.

    Args:
        n: number of colors to return. ``n == 0`` yields ``[]``; ``n == 1``
            yields the colormap's first color.
        colormap_name: name of a registered Matplotlib colormap.
        alpha: opacity written into every returned color.

    Returns:
        List of ``n`` ``(r, g, b, alpha)`` tuples with r, g, b in [0, 1].
    """
    import matplotlib

    # `cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9; use the
    # colormap registry when available and fall back for older versions.
    registry = getattr(matplotlib, "colormaps", None)
    if registry is not None:
        colormap = registry[colormap_name]
    else:
        colormap = cm.get_cmap(colormap_name)

    # Guard the denominator: with n == 1 the original `i / (n - 1)` was 0 / 0.
    denom = max(n - 1, 1)
    return [(*colormap(i / denom)[:3], alpha) for i in range(n)]


def visualize_pointcloud(
    stacked_points,
    point_radius=0.005,
    colormap="viridis",
    alpha=1.0,
    points_per_part=128,
):
    """
    Build a trimesh scene that draws each part's points as small colored spheres.

    Every part gets one color sampled from ``colormap``; all spheres of that part
    share it.

    Args:
        stacked_points: indexable collection of per-part point sets;
            ``stacked_points[i]`` is sliced and iterated as a sequence of 3D
            points (presumably an array of shape (n_points, 3) — TODO confirm).
        point_radius: radius of the sphere drawn for each point.
        colormap: Matplotlib colormap name used to color the parts.
        alpha: opacity applied to every part color.
        points_per_part: only the first ``points_per_part`` points of each part
            are drawn.

    Returns:
        A ``trimesh.Scene`` containing one sphere geometry per drawn point.
    """
    scene = trimesh.Scene()

    # One color per part.
    n_parts = len(stacked_points)
    colors = generate_colormap_colors(n_parts, colormap_name=colormap, alpha=alpha)

    # The sphere tessellation is identical for every point: build it once and
    # copy it, instead of re-tessellating a new sphere per point.
    template_sphere = trimesh.creation.uv_sphere(radius=point_radius)

    for i, color in enumerate(colors):
        part_points = stacked_points[i]
        # RGBA in [0, 1] scaled to [0, 255]; constant for the whole part.
        face_color = np.array(color) * 255

        for point in part_points[:points_per_part]:
            sphere = template_sphere.copy()
            sphere.apply_translation(point)
            sphere.visual.face_colors = face_color
            scene.add_geometry(sphere)

    return scene


def center_trimesh_mesh(mesh):
    """
    Move ``mesh`` so that its centroid sits at the origin.

    The translation is applied in place; the same mesh object is returned so
    the call can be chained.
    """
    offset = -mesh.centroid
    mesh.apply_translation(offset)
    return mesh


def to_canonical(mesh):
    """
    Rotate the mesh into the frame of its oriented bounding box, then center it.

    The 3x3 rotation block of ``mesh.bounding_box_oriented.transform`` is
    inverted via its transpose (assumes that block is orthonormal) and applied
    to the mesh. The mesh is then translated so its centroid is at the origin.
    The OBB translation is discarded; centering handles position instead.

    The mesh is modified in place and also returned.
    """
    obb_rotation = mesh.bounding_box_oriented.transform[:3, :3]

    # Build a proper 4x4 homogeneous transform. Padding the 3x3 block with zeros
    # (the previous approach) left entry [3, 3] at 0, which is not a valid
    # homogeneous matrix.
    inv_transform = np.eye(4)
    inv_transform[:3, :3] = obb_rotation.T

    mesh.apply_transform(inv_transform)
    return center_trimesh_mesh(mesh)

In [None]:
part_dir = "/ibex/project/c2273/3DCoMPaT/manifold_part_instances/"
# Pickled part-instance files to process, in a deterministic (sorted) order.
all_objs = sorted(
    fname for fname in os.listdir(part_dir) if fname.endswith(".pkl")
)


def load_part_meshes(f_name):
    """
    Load a pickled mapping of part meshes and apply the class-specific transform.

    The shape class is decoded with ``hex_to_class`` from the file-name prefix,
    i.e. the text before the first ``_`` in the basename. If that class has an
    entry in ``COMPAT_TRANSFORMS``, the entry is used as the linear 3x3 block
    of a 4x4 homogeneous transform; otherwise the identity is used.

    Args:
        f_name: path to a ``.pkl`` file.

    Returns:
        The unpickled mapping. Every value has had ``apply_transform`` called
        on it (in place).
    """
    # NOTE: pickle.load can execute arbitrary code; only use on trusted dataset files.
    with open(f_name, "rb") as f:
        data = pickle.load(f)

    class_hex = os.path.basename(f_name).split("_")[0]
    shape_cls = hex_to_class(class_hex)

    mat = np.eye(4)
    if shape_cls in COMPAT_TRANSFORMS:
        # NOTE(review): entries are assumed to be 3x3 linear maps (they were
        # previously zero-padded to 4x4 with mat[3, 3] = 1) — confirm.
        mat[:3, :3] = COMPAT_TRANSFORMS[shape_cls]

    for mesh in data.values():
        mesh.apply_transform(mat)
    return data


In [None]:
N_POINTS_PER_SHAPE = 2**17
OUT_DIR = "/ibex/project/c2273/PADS/3DCoMPaT_occ/parts"
# Set to True to resume an interrupted run by skipping objects already written.
SKIP_EXISTING = False

# Make sure the output directory exists so np.save does not fail on a fresh run.
os.makedirs(OUT_DIR, exist_ok=True)

for obj_name in tqdm(all_objs):
    # Only the ".pkl" suffix is swapped (str.replace would hit every occurrence).
    out_f_name = os.path.join(OUT_DIR, os.path.splitext(obj_name)[0] + ".npy")
    if SKIP_EXISTING and os.path.exists(out_f_name):
        continue

    mesh_segs = load_part_meshes(os.path.join(part_dir, obj_name))

    # Canonicalize each part, sample points on its surface and normalize them.
    all_points = {}
    for part_key, part_mesh in mesh_segs.items():
        part_mesh = to_canonical(part_mesh)
        points = sample_surface(CUDAMesh.from_trimesh(part_mesh), N_POINTS_PER_SHAPE)
        all_points[part_key] = normalize_pc(points, method="per_axis")

    # np.save stores the dict as a pickled 0-d object array; read it back with
    # np.load(path, allow_pickle=True).item().
    np.save(out_f_name, all_points)