In [1]:
%cd /ibex/user/slimhy/PADS/code/
"""
Extracting features into HDF5 files for each split.
"""
import argparse
import inspect
import json

import numpy as np
import torch
import trimesh

import util.misc as misc
import util.s2vs as s2vs

from datasets.shapeloaders import CoMPaTManifoldDataset, PartNetManifoldDataset, SingleManifoldDataset
from util.misc import d_GPU, show_side_by_side

/ibex/user/slimhy/PADS/code


In [2]:
from datasets.metadata import (
    COMPAT_CLASSES,
    int_to_hex,
)
import os
from util.misc import CUDAMesh


class CoMPaTSegDataset(SingleManifoldDataset):
    """
    Sampling from a 3DCoMPaT manifold mesh dataset with segmentation labels.

    On top of the ``.obj`` manifold meshes listed in ``self.obj_files``, this
    dataset records the path of one ``.gltf`` segmentation file per mesh in
    ``self.seg_files`` (same order as ``self.obj_files``).

    NOTE(review): ``robust_pcu_to_manifold``, ``decimate_mesh`` and
    ``normalize_pc`` are referenced below but are not imported in any visible
    notebook cell. They must be brought into scope before the watertight
    repair, decimation and normalization branches can run.
    """

    def __init__(
        self,
        *args,
        seg_dir,
        **kwargs,
    ):
        """
        Args:
            *args: forwarded unchanged to ``SingleManifoldDataset``.
            seg_dir: directory holding the ``.gltf`` segmentation files
                (keyword-only).
            **kwargs: forwarded unchanged to ``SingleManifoldDataset``.
        """
        # Assigned before the parent constructor runs; presumably the parent
        # calls init_class_objs(), which reads self.seg_dir -- TODO confirm.
        self.seg_dir = seg_dir
        super().__init__(*args, **kwargs)
        if self.normalize:
            print(
                "normalize=True but 3DCoMPaT shapes are already normalized to their bounding boxes."
            )

    def get_mesh(self, idx=None):
        """
        Load the mesh from the given index.

        The loaded mesh is cached in ``self.mesh``; while a mesh is cached it is
        returned as-is, whatever ``idx`` is. Callers switching to another shape
        must clear ``self.mesh`` first (``__getitem__`` does so).

        Args:
            idx: position in ``self.obj_files``; defaults to ``self.mesh_idx``.

        Returns:
            The cached / freshly loaded ``CUDAMesh``.

        Raises:
            ValueError: if the mesh is not watertight and the repair fails.
        """
        if idx is None:
            idx = self.mesh_idx

        if self.mesh is None:
            obj_path = self.obj_files[idx]
            self.mesh = CUDAMesh.load(obj_path, to_cuda=self.to_cuda)

            # Print an alert if the mesh is not watertight
            if not self.mesh.is_watertight:
                print("Mesh is not watertight! Performing robust conversion...")
                repaired_path = "/tmp/" + os.path.basename(obj_path)
                robust_pcu_to_manifold(obj_path, repaired_path)
                # Try to load and test if watertight
                self.mesh = CUDAMesh.load(repaired_path, to_cuda=self.to_cuda)
                if not self.mesh.is_watertight:
                    # Do not keep a broken mesh in the cache: a later call
                    # would otherwise return it without any check.
                    self.mesh = None
                    raise ValueError("Watertight conversion failed!")

                # Replace the original mesh with the watertight one
                # Write to original file (this overwrites the dataset file on disk)
                self.mesh.export(obj_path)
                print("Watertight conversion successful!")

            # Decimate the mesh if it has too many faces
            if self.decimate and len(self.mesh.faces) > self.MAX_FACES:
                # The ratio is the percentage of faces to REMOVE
                ratio = 1 - self.MAX_FACES / len(self.mesh.faces)
                self.mesh = decimate_mesh(self.mesh, ratio)

        return self.mesh

    def set_seg(self, idx=None):
        """
        Load the segmentation from the given index.

        Not implemented yet: this is a placeholder that does nothing.
        """
        pass

    def __getitem__(self, idx):
        """
        Yield ``self.max_it`` samples ``(points, occs)`` for the shape at ``idx``.

        This is a generator function: ``dataset[idx]`` returns a generator and
        nothing (including mesh loading) happens until it is first advanced.

        Raises:
            ValueError: if ``sample_first`` is set but ``n_points`` is smaller
                than ``MAX_SAMPLE_SIZE``, so that no batch can be pre-sampled.
        """
        if self.mesh_idx != idx or self.mesh is None:
            self.mesh_idx = idx
            # Drop the cached mesh: get_mesh() only loads when the cache is
            # empty, so without this a new index would keep the previous mesh.
            self.mesh = None
            self.get_mesh(idx)

        # Optionally: first sample n_points first
        # And simply serve the same points for the rest of the iterations
        if self.sample_first:
            # Use batch sampling
            n_batches = self.n_points // self.MAX_SAMPLE_SIZE
            if n_batches == 0:
                # Fail with an explicit message instead of letting
                # np.random.choice choke on an empty list further down.
                raise ValueError(
                    "sample_first requires n_points >= MAX_SAMPLE_SIZE "
                    "(got n_points=%d, MAX_SAMPLE_SIZE=%d)"
                    % (self.n_points, self.MAX_SAMPLE_SIZE)
                )
            all_points, all_occs = [], []
            for k in range(n_batches):
                if k % 4 == 0:
                    print("Sampling batch [%d/%d]" % (k + 1, n_batches))
                points, occs = self.sampling_fn(self.mesh, self.MAX_SAMPLE_SIZE)
                all_points.append(points)
                all_occs.append(occs)
            print()
            points_idx = list(range(len(all_points)))

        # Resample the mesh
        for _ in range(self.max_it):
            if self.sample_first:
                # Serve one of the pre-sampled batches, picked at random
                rnd_idx = np.random.choice(points_idx)
                points = all_points[rnd_idx]
                occs = all_occs[rnd_idx]
            else:
                points, occs = self.sampling_fn(self.mesh, self.n_points)

            # Optionally: normalize the point cloud
            if self.normalize:
                points = normalize_pc(points)
            yield points, occs

    def init_class_objs(self):
        """
        Set the list of objects for a given class/split.

        Populates ``self.obj_files`` (sorted ``.obj`` paths in ``self.obj_dir``
        whose name contains the class code) and ``self.seg_files`` (the
        corresponding ``.gltf`` paths in ``self.seg_dir``, in the same order).
        For any split other than ``"all"``, only the shapes listed under that
        split in ``CoMPaT/split.json`` are kept.
        """
        compat_cls_code = int_to_hex(COMPAT_CLASSES[self.shape_cls])
        obj_names = [
            f
            for f in os.listdir(self.obj_dir)
            if f.endswith(".obj") and compat_cls_code + "_" in f
        ]

        if self.split != "all":
            # Locate the split metadata next to the defining module. Inside a
            # notebook kernel __file__ does not exist, so fall back to the
            # module that defines the parent class (assumed to sit in the same
            # package directory as the CoMPaT/ folder -- TODO confirm).
            try:
                module_path = __file__
            except NameError:
                module_path = inspect.getfile(SingleManifoldDataset)
            pwd = os.path.dirname(os.path.realpath(module_path))

            # Open the split metadata (closed as soon as it is parsed)
            with open(os.path.join(pwd, "CoMPaT", "split.json")) as split_f:
                split_dict = json.load(split_f)

            # Filter split meshes
            obj_names = [
                f for f in obj_names if f.split(".")[0] in split_dict[self.split]
            ]

        # Sort once and derive both lists from the same ordering so that
        # obj_files[i] and seg_files[i] always describe the same shape.
        obj_names.sort()
        self.obj_files = [os.path.join(self.obj_dir, f) for f in obj_names]
        self.seg_files = [
            os.path.join(self.seg_dir, os.path.splitext(f)[0] + ".gltf")
            for f in obj_names
        ]

In [3]:
# Shape category to load and locations of the 3DCoMPaT data on disk.
ACTIVE_CLASS = "chair"
OBJ_DIR  = "/ibex/project/c2273/3DCoMPaT/obj_manifold/"
SEG_DIR  = "/ibex/project/c2273/3DCoMPaT/gltf/"
ZIP_PATH = "/ibex/project/c2273/3DCoMPaT/3DCoMPaT_ZIP.zip"
# Index (into dataset.obj_files) of the shape to inspect.
OBJ_ID = 0
# Third positional argument of the dataset constructor; presumably the number
# of points drawn per sample -- TODO confirm against SingleManifoldDataset.
N_POINTS = 10000

dataset = CoMPaTSegDataset(
    OBJ_DIR,
    ACTIVE_CLASS,
    N_POINTS,
    seg_dir=SEG_DIR,
    normalize=False,
    sampling_method="surface",
    to_cuda=False
)
# dataset[idx] is a generator of (points, occs) pairs: take the first sample.
surface_points, _ = next(dataset[OBJ_ID])

In [4]:
import os
import zipfile
import json
from datasets.CoMPaT.utils3D import gltf

# The archive stays open for the rest of the notebook: ZipTextureResolver reads
# texture files from it whenever a GLTF file is loaded.
zip_f = zipfile.ZipFile(ZIP_PATH, "r")
# Close the archive member handle as soon as the JSON has been parsed.
with zip_f.open("textures_map.json", "r") as textures_map_f:
    textures_map = json.load(textures_map_f)

class ZipTextureResolver(trimesh.resolvers.FilePathResolver):
    """
    Resolve texture files from the input zip.

    NOTE(review): the parent constructor is not called, so only ``get`` is
    known to be usable on instances of this class -- confirm that trimesh does
    not rely on other resolver attributes when loading these GLTF files.
    """
    def __init__(self, zip_f):
        """
        Args:
            zip_f: an open ``zipfile.ZipFile`` containing the texture files.
        """
        self.zip_f = zip_f

    def get(self, file_path):
        """
        Return the raw bytes of the archive member ``file_path``.

        Raises:
            KeyError: if the archive has no member with that name.
        """
        # ZipFile.read opens, reads and closes the member in a single call,
        # so no file handle is left dangling.
        return self.zip_f.read(file_path)

def load_gltf(gltf_f):
    """
    Load a GLTF segmentation file as a trimesh scene.

    The file is first passed through ``gltf.apply_placeholder`` together with
    the global ``textures_map``; textures are then resolved from the global
    ``zip_f`` archive through ``ZipTextureResolver``.

    Args:
        gltf_f: path of the ``.gltf`` file to load.

    Returns:
        The object returned by ``trimesh.load`` with ``force="scene"``.
    """
    # Keep the source file open until trimesh is done, in case the patched
    # stream still reads from it, then close it deterministically.
    with open(gltf_f) as gltf_stream:
        patched_gltf = gltf.apply_placeholder(gltf_stream, textures_map)
        return trimesh.load(
            patched_gltf,
            file_type=".gltf",
            force="scene",
            resolver=ZipTextureResolver(zip_f),
        )
    
# Segmentation file of the shape selected by OBJ_ID (same base name as its
# .obj file, stored in SEG_DIR), so that it matches the sampled points.
gltf_file = os.path.join(
    SEG_DIR,
    os.path.splitext(os.path.basename(dataset.obj_files[OBJ_ID]))[0] + ".gltf",
)
gltf_model = load_gltf(gltf_file)

In [6]:
import k3d

# Mute every warning from here on (meant to hide the traittypes noise)
import warnings
warnings.filterwarnings("ignore")


# Build the K3D scene: the loaded mesh plus the sampled surface points,
# both drawn in the same light grey.
mesh = dataset.get_mesh().trimesh_mesh
plot = k3d.plot()
for drawable in (
    k3d.mesh(mesh.vertices, mesh.faces, color=0xe1e0df),
    k3d.points(surface_points, point_size=0.01, color=0xe1e0df),
):
    plot += drawable

# Render the scene in the notebook
plot.display()

Output()