In [None]:
%cd /ibex/user/slimhy/PADS/code/
"""
Extracting features into HDF5 files for each split.
"""
import argparse
import torch
import trimesh

import util.misc as misc
import util.s2vs as s2vs

from datasets.shapeloaders import CoMPaTManifoldDataset, PartNetManifoldDataset, SingleManifoldDataset
from util.misc import d_GPU, show_side_by_side

In [None]:
from datasets.metadata import (
    COMPAT_CLASSES,
    int_to_hex,
)
import os
from util.misc import CUDAMesh


class CoMPaTSegDataset(SingleManifoldDataset):
    """
    Sampling from a 3DCoMPaT manifold mesh dataset with segmentation labels.

    Extends ``SingleManifoldDataset`` with ``seg_dir``, a directory holding one
    ``.gltf`` segmentation file per manifold ``.obj`` mesh. ``init_class_objs``
    fills ``self.obj_files`` and ``self.seg_files`` in matching order.

    NOTE(review): ``robust_pcu_to_manifold``, ``decimate_mesh`` and
    ``normalize_pc`` are not imported in this notebook; presumably they come
    from the module defining ``SingleManifoldDataset`` -- confirm before relying
    on the non-watertight, decimation or ``normalize`` code paths.
    """

    def __init__(
        self,
        *args,
        seg_dir,
        **kwargs,
    ):
        # Set before super().__init__ in case the parent constructor calls
        # init_class_objs(), which reads self.seg_dir (TODO confirm).
        self.seg_dir = seg_dir
        super().__init__(*args, **kwargs)
        if self.normalize:
            print(
                "normalize=True but 3DCoMPaT shapes are already normalized to their bounding boxes."
            )

    def get_mesh(self, idx=None):
        """
        Load (and cache) the mesh at index ``idx`` of ``self.obj_files``.

        If the loaded mesh is not watertight, it is converted with
        ``robust_pcu_to_manifold`` and the ORIGINAL .obj file is overwritten
        with the repaired mesh. When ``self.decimate`` is set, meshes with more
        than ``self.MAX_FACES`` faces are decimated.

        Args:
            idx: Index into ``self.obj_files``; defaults to ``self.mesh_idx``.

        Returns:
            The cached mesh for ``idx``.

        Raises:
            ValueError: if the watertight conversion fails.
        """
        if idx is None:
            idx = self.mesh_idx

        # The cache holds a single mesh. Drop it if it was loaded for another
        # index; otherwise get_mesh(j) would silently return mesh i.
        if self.mesh is not None and getattr(self, "_loaded_idx", idx) != idx:
            self.mesh = None

        if self.mesh is None:
            obj_path = self.obj_files[idx]
            self.mesh = CUDAMesh.load(obj_path, to_cuda=self.to_cuda)

            # Print an alert if the mesh is not watertight
            if not self.mesh.is_watertight:
                print("Mesh is not watertight! Performing robust conversion...")
                tmp_path = os.path.join("/tmp", os.path.basename(obj_path))
                robust_pcu_to_manifold(obj_path, tmp_path)
                # Reload the converted mesh and check it is now watertight
                self.mesh = CUDAMesh.load(tmp_path, to_cuda=self.to_cuda)
                if not self.mesh.is_watertight:
                    raise ValueError("Watertight conversion failed!")

                # Persist the fix: the original .obj file is overwritten
                self.mesh.export(obj_path)
                print("Watertight conversion successful!")

            # Decimate the mesh if it has too many faces
            if self.decimate and len(self.mesh.faces) > self.MAX_FACES:
                # The ratio is the fraction of faces to REMOVE
                ratio = 1 - self.MAX_FACES / len(self.mesh.faces)
                self.mesh = decimate_mesh(self.mesh, ratio)

            self._loaded_idx = idx

        return self.mesh

    def set_seg(self, idx=None):
        """
        Load the segmentation from the given index.

        TODO: not implemented yet; this is currently a no-op placeholder.
        """
        pass

    def __getitem__(self, idx):
        """
        Return a generator yielding ``self.max_it`` ``(points, occs)`` samples
        drawn from mesh ``idx`` with ``self.sampling_fn``.

        NOTE: this is a generator function. Nothing, including loading the
        mesh, happens until the result is iterated (e.g. ``next(ds[i])``).
        """
        # np is only imported in a later notebook cell; import it locally so
        # the sample_first path does not depend on cell execution order.
        import numpy as np

        if self.mesh_idx != idx or self.mesh is None:
            self.mesh_idx = idx
            # get_mesh() only loads when self.mesh is None, so clear the cache
            # or a new idx would silently reuse the previously loaded mesh.
            self.mesh = None
            self.get_mesh(idx)

        # Optionally: sample a pool of point batches once and serve random
        # batches from it for the rest of the iterations.
        if self.sample_first:
            # Always sample at least one batch. With n_points < MAX_SAMPLE_SIZE
            # the pool used to be empty and np.random.choice([]) raised.
            n_batches = max(1, self.n_points // self.MAX_SAMPLE_SIZE)
            all_points, all_occs = [], []
            for k in range(n_batches):
                if k % 4 == 0:
                    print("Sampling batch [%d/%d]" % (k + 1, n_batches))
                points, occs = self.sampling_fn(self.mesh, self.MAX_SAMPLE_SIZE)
                all_points.append(points)
                all_occs.append(occs)
            print()
            points_idx = list(range(len(all_points)))

        # Resample the mesh
        for _ in range(self.max_it):
            if self.sample_first:
                rnd_idx = np.random.choice(points_idx)
                points = all_points[rnd_idx]
                occs = all_occs[rnd_idx]
            else:
                points, occs = self.sampling_fn(self.mesh, self.n_points)

            # Optionally: normalize the point cloud
            if self.normalize:
                points = normalize_pc(points)
            yield points, occs

    def init_class_objs(self):
        """
        Set ``self.obj_files`` and the matching ``self.seg_files`` for the
        current class/split.

        ``.obj`` files in ``self.obj_dir`` are kept when their name contains the
        hex class code followed by ``_``. When ``self.split != "all"``, they are
        further filtered by ``CoMPaT/split.json``. ``seg_files[i]`` is the
        ``.gltf`` counterpart (in ``self.seg_dir``) of ``obj_files[i]``.

        NOTE(review): ``__file__`` is undefined when this runs inside a
        notebook, so any split other than "all" only works from a module.
        """
        # json is only imported in another notebook cell
        import json

        compat_cls_code = int_to_hex(COMPAT_CLASSES[self.shape_cls])
        obj_files = [
            f
            for f in os.listdir(self.obj_dir)
            if f.endswith(".obj") and compat_cls_code + "_" in f
        ]

        if self.split != "all":
            # Open the split metadata
            pwd = os.path.dirname(os.path.realpath(__file__))
            with open(os.path.join(pwd, "CoMPaT", "split.json")) as fp:
                split_dict = json.load(fp)
            # Set for O(1) membership tests
            split_ids = set(split_dict[self.split])
            obj_files = [f for f in obj_files if f.split(".")[0] in split_ids]

        # Sort once and derive both lists from that order. Sorting the .obj and
        # .gltf names independently can misalign the pairs, because the
        # suffixes compare differently (e.g. "a.obj" vs "a.o.obj").
        obj_files = sorted(obj_files)
        self.obj_files = [os.path.join(self.obj_dir, f) for f in obj_files]
        self.seg_files = [
            os.path.join(self.seg_dir, f.replace(".obj", ".gltf")) for f in obj_files
        ]

In [None]:
import os
import zipfile
import json
from datasets.CoMPaT.compat3D import SegmentedMeshLoader
from datasets.CoMPaT.utils3D.plot import label_to_RGB, FINE_RGB_RANGE

ACTIVE_CLASS = "chair"
OBJ_DIR = "/ibex/project/c2273/3DCoMPaT/obj_manifold/"
SEG_DIR = "/ibex/project/c2273/3DCoMPaT/gltf/"
ZIP_PATH = "/ibex/project/c2273/3DCoMPaT/3DCoMPaT_ZIP.zip"
META_DIR = "/ibex/project/c2273/3DCoMPaT/3DCoMPaT-v2/metadata"
OBJ_ID = 10


# --- Manifold meshes of the active class, sampled on their surface ---------
manifold_dataset = CoMPaTSegDataset(
    OBJ_DIR,
    ACTIVE_CLASS,
    10000,
    seg_dir=SEG_DIR,
    normalize=False,
    sampling_method="surface",
    to_cuda=False,
)
# Indexing returns a generator; pulling one sample also loads mesh OBJ_ID
surface_points, _ = next(manifold_dataset[OBJ_ID])
manifold_mesh = manifold_dataset.get_mesh(OBJ_ID)


# --- Part-segmented meshes read from the 3DCoMPaT archive --------------------
# NOTE(review): this loader uses split="train" while manifold_dataset uses its
# default split; confirm the model at OBJ_ID exists in the train split.
seg_dataset = SegmentedMeshLoader(
    zip_path=ZIP_PATH,
    meta_dir=META_DIR,
    split="train",
    shuffle=True,
    seed=0,
)

# Link both datasets through the model id: file name up to its first "."
obj_file_name = manifold_dataset.obj_files[OBJ_ID].rsplit("/", 1)[-1]
model_id = obj_file_name.split(".")[0]
mesh_id = seg_dataset.get_model_index(model_id)
# Iterated below as {part label: part mesh}
mesh_map = seg_dataset[mesh_id]

In [None]:
import numpy as np
from scipy.spatial import cKDTree


def sample_face_points(mesh, num_samples_per_face, random_sampling=True, rng=None):
    """
    Sample ``num_samples_per_face`` points on every triangle of a mesh.

    Args:
        mesh: Object exposing ``vertices`` (V, 3) and triangular ``faces`` (F, 3).
        num_samples_per_face: Number of points generated on each face.
        random_sampling: If True, points are drawn uniformly at random over each
            triangle. If False, the same deterministic, well-spread set of points
            (the R2 low-discrepancy sequence mapped onto the triangle) is placed
            on every face.
        rng: Optional ``np.random.Generator`` for reproducible random sampling.
            By default the global ``np.random`` state is used, as before.

    Returns:
        (F * num_samples_per_face, 3) array. The samples of face ``f`` are rows
        ``f * num_samples_per_face`` to ``(f + 1) * num_samples_per_face - 1``.
    """
    num_faces = len(mesh.faces)
    total_samples = num_faces * num_samples_per_face

    # (F, 3, 3): the three corner positions of every face
    face_vertices = mesh.vertices[mesh.faces]

    if random_sampling:
        draw = np.random.random if rng is None else rng.random
        u1, u2 = draw((2, total_samples))
    else:
        # Deterministic: the first S points of the R2 sequence on the unit
        # square, shared by all faces. This replaces the old placeholder, which
        # returned S identical copies of the face centroid.
        g = 1.324717957244746  # plastic number, the real root of x**3 = x + 1
        k = np.arange(1, num_samples_per_face + 1)
        u1 = (0.5 + k / g) % 1.0
        u2 = (0.5 + k / g**2) % 1.0

    # Map the unit square uniformly onto the triangle (sqrt warp) and get
    # barycentric weights that sum to 1.
    r1 = np.sqrt(u1)
    bary = np.stack([1 - r1, r1 * (1 - u2), r1 * u2], axis=-1)

    if random_sampling:
        bary = bary.reshape(num_faces, num_samples_per_face, 3)
    else:
        # One pattern broadcast over every face
        bary = bary[np.newaxis]

    # (F, S, 3) @ (F, 3, 3) -> (F, S, 3): weighted sum of the corners
    points = np.matmul(bary, face_vertices)

    return points.reshape(-1, 3)


def sample_face_points_center(mesh):
    """
    Return the centroid of every face of ``mesh``.

    Args:
        mesh: Object exposing ``vertices`` (V, 3) and equal-sized ``faces`` (F, k).

    Returns:
        (F, 3) float64 array whose row ``i`` is the mean of face ``i``'s vertices.
    """
    faces = np.asarray(mesh.faces, dtype=np.int64)
    if faces.size == 0:
        return np.zeros((0, 3))

    vertices = np.asarray(mesh.vertices, dtype=np.float64)
    # One vectorized gather + mean instead of a Python-level loop over faces
    return vertices[faces].mean(axis=1)
    

def find_closest_meshes(mesh_list, pointcloud, num_samples_per_face=50):
    """
    Assign every query point to the closest mesh in ``mesh_list``.

    Each mesh is approximated by ``num_samples_per_face`` random surface points
    per face (via ``sample_face_points``). A point's distance to a mesh is its
    distance to the nearest of those samples.

    Args:
        mesh_list: Iterable of meshes exposing ``vertices`` and ``faces``
            (e.g. ``dict.values()``); it is iterated once.
        pointcloud: Query points, (N, 3).
        num_samples_per_face: Surface samples per face (previously hard-coded
            to 50, still the default).

    Returns:
        (N,) int array: for each point, the position in ``mesh_list``
        (iteration order) of the closest mesh. Ties go to the earlier mesh.

    Raises:
        ValueError: if ``mesh_list`` is empty while ``pointcloud`` is not
            (every point would otherwise silently be assigned index 0).
    """
    n_points = len(pointcloud)
    closest_mesh_indices = np.zeros(n_points, dtype=int)
    min_distances = np.full(n_points, np.inf)

    n_meshes = 0
    for i, mesh in enumerate(mesh_list):
        n_meshes += 1

        # Sample points on the surface of the mesh
        sampled_points = sample_face_points(mesh, num_samples_per_face, True)

        # KD-tree over the surface samples (not the mesh vertices)
        tree = cKDTree(sampled_points)

        # Distance from each query point to its nearest surface sample
        distances, _ = tree.query(pointcloud)

        # Update closest_mesh_indices where this mesh is strictly closer
        closer_points = distances < min_distances
        closest_mesh_indices[closer_points] = i
        min_distances[closer_points] = distances[closer_points]

    if n_meshes == 0 and n_points > 0:
        raise ValueError("find_closest_meshes: mesh_list is empty")

    return closest_mesh_indices

In [None]:
import numpy as np
from multiprocessing import Pool
from functools import partial

# Silence traittypes warnings.
# NOTE(review): filterwarnings("ignore") without a category/module filter
# disables *every* warning for the rest of the session, not only traittypes'.
import warnings
warnings.filterwarnings("ignore")

import k3d

# One Rainbow color per part, then keyed by the part label (mesh_map order)
unique_parts = np.arange(len(mesh_map))
col_map = k3d.helpers.map_colors(unique_parts, k3d.colormaps.basic_color_maps.Rainbow)
col_map = dict(zip(mesh_map, col_map))

def get_point_colors(mesh_map, manifold_mesh):
    """
    Label every face of ``manifold_mesh`` with its nearest part in ``mesh_map``.

    Face centers of ``manifold_mesh.trimesh_mesh`` are matched against the part
    meshes (``mesh_map`` values) with ``find_closest_meshes``. Colors come from
    the notebook-global ``col_map``.

    Returns:
        col_list: one color per part, in ``mesh_map`` iteration order.
        all_idx: for each face, the index of its closest part.
        point_colors: the ``col_list`` entry for each face.
    """
    col_list = [col_map[label] for label in mesh_map]
    part_meshes = mesh_map.values()

    centers = sample_face_points_center(manifold_mesh.trimesh_mesh)
    all_idx = find_closest_meshes(part_meshes, centers)

    point_colors = list(map(col_list.__getitem__, all_idx))
    return col_list, all_idx, point_colors

In [None]:
import numpy as np
import trimesh

def hex_to_rgb(hex_color):
    """
    Split a packed 0xRRGGBB color into its ``(r, g, b)`` components (0-255).

    Args:
        hex_color: Packed color as an integer (Python or NumPy scalar), or a
            hex string such as ``"#ff8800"``, ``"0xff8800"`` or ``"ff8800"``.
            Bits above the low 24 (e.g. an alpha byte) are ignored.

    Returns:
        Tuple of three Python ints ``(r, g, b)``.
    """
    if isinstance(hex_color, str):
        # int(..., 16) also accepts an optional "0x" prefix
        hex_color = int(hex_color.strip().lstrip("#"), 16)

    # Keep only the low 24 bits. For a negative value (e.g. a color stored as a
    # signed int32) this is the two's-complement reinterpretation; the former
    # abs() produced an unrelated color.
    value = int(hex_color) & 0xFFFFFF

    return (value >> 16) & 0xFF, (value >> 8) & 0xFF, value & 0xFF


def assign_colors(mesh, face_color):
    """
    Build a uniformly colored, two-sided copy of ``mesh``.

    A fresh ``trimesh.Trimesh`` is built from the input's vertices and faces.
    Every face is duplicated with reversed winding so it shows from both sides,
    and all faces (original and duplicate) get ``face_color``.

    Args:
        mesh: Any object exposing ``vertices`` and ``faces``.
        face_color: Packed 0xRRGGBB color, converted with ``hex_to_rgb``.

    Returns:
        The new colored ``trimesh.Trimesh``; ``visual.two_sided`` is set to True.
    """
    out = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces)

    # One RGB row per (original) face
    colors = np.tile(hex_to_rgb(face_color), (len(out.faces), 1))

    # Append every face again with reversed vertex order (back faces)
    out.faces = np.concatenate([out.faces, out.faces[:, ::-1]])

    # Back faces reuse the front colors
    out.visual.face_colors = np.concatenate([colors, colors])
    out.visual.two_sided = True

    return out

def remove_faces_and_assign_colors(mesh, face_colors, faces_to_remove):
    """
    Return a copy of ``mesh`` without some faces, colored per face and two-sided.

    Args:
        mesh: Mesh with ``copy``, ``update_faces`` and
            ``remove_unreferenced_vertices`` (a ``trimesh.Trimesh``); not modified.
        face_colors: One packed 0xRRGGBB color per face of ``mesh``.
        faces_to_remove: Indices (or a boolean mask) of the faces to drop.

    Returns:
        The trimmed copy. Each kept face also appears with reversed winding and
        the same color, and ``visual.two_sided`` is set to True.
    """
    out = mesh.copy()
    rgb = np.array(list(map(hex_to_rgb, face_colors)))

    # True for every face that survives
    keep = np.full(len(out.faces), True)
    keep[faces_to_remove] = False

    out.update_faces(keep)
    out.remove_unreferenced_vertices()
    kept_rgb = rgb[keep]

    # Back faces: same triangles with reversed vertex order, same colors
    out.faces = np.concatenate([out.faces, out.faces[:, ::-1]])
    out.visual.face_colors = np.concatenate([kept_rgb, kept_rgb])
    out.visual.two_sided = True

    return out


def show_part(my_mesh, mesh_map, part_label, face_part_idx=None, color_map=None):
    """
    Return a colored copy of ``my_mesh`` keeping only the faces of one part.

    Args:
        my_mesh: Mesh whose faces were labeled (e.g. the manifold mesh passed to
            ``get_point_colors``).
        mesh_map: Mapping part label -> part mesh; only its key order is used.
        part_label: INDEX of the part to keep (a value of ``face_part_idx``,
            i.e. a position in ``mesh_map``), not the label key itself.
        face_part_idx: Per-face part index for ``my_mesh``. Defaults to the
            notebook-global ``all_idx`` (the previous, implicit behavior).
        color_map: Mapping part label -> packed 0xRRGGBB color. Defaults to the
            notebook-global ``col_map``.

    Returns:
        The mesh produced by ``remove_faces_and_assign_colors``.
    """
    # Fall back to the notebook globals the function used to read implicitly
    if face_part_idx is None:
        face_part_idx = all_idx
    if color_map is None:
        color_map = col_map
    face_part_idx = np.asarray(face_part_idx)

    # Indices of all faces that do NOT belong to the requested part
    faces_to_hide = np.flatnonzero(face_part_idx != part_label)

    # Color every face by the label of the part it was assigned to
    mesh_labels = list(mesh_map.keys())
    face_colors = [color_map[mesh_labels[i]] for i in face_part_idx]

    # Apply the colors and hide the other parts' faces
    return remove_faces_and_assign_colors(my_mesh, face_colors, faces_to_hide)

# Get the point colors
col_list, all_idx, point_colors = get_point_colors(mesh_map, manifold_mesh)

part_id = 5
proj_part = show_part(manifold_mesh.trimesh_mesh, mesh_map, part_id)
orig_part_label, orig_part = list(mesh_map.items())[part_id]
orig_part = assign_colors(orig_part, col_map[orig_part_label])

show_side_by_side(proj_part, orig_part)