In [None]:
%cd /ibex/user/slimhy/PADS/code/

import argparse
import numpy as np
from datasets.shapeloaders import CoMPaTSegmentDataset
from util.misc import dump_pickle
from datasets.metadata import COMPAT_MATCHED_CLASSES


# Surface points sampled per shape (1 << 17 == 131072).
N_POINTS_PER_SHAPE = 1 << 17
# Directory that receives every exported .npy / .pkl file.
OUT_PATH = "/ibex/project/c2273/3DCoMPaT/manifold_points"
# Samples exported per augmentation variant of each object.
SAMPLES_PER_DATASET = 16


def get_datasets(active_class):
    """Build every dataset variant exported for one CoMPaT class.

    All variants read the same root directory and share the point budget,
    surface sampling and recenter / process / ShapeNet scale & alignment
    flags. They differ only in the augmentation flags listed in
    ``variant_flags`` below.

    Args:
        active_class: class name forwarded as ``shape_cls``.

    Returns:
        dict mapping augmentation id ("orig", "part_drop", "rand_rot",
        "rand_no_rot", "all_aug") to its ``CoMPaTSegmentDataset``.
    """
    root_dir = "/ibex/project/c2273/3DCoMPaT/manifold_part_instances/"
    shared = dict(
        shape_cls=active_class,
        n_points=N_POINTS_PER_SHAPE,
        sampling_method="surface",
        recenter_mesh=True,
        process_mesh=True,
        scale_to_shapenet=True,
        align_to_shapenet=True,
    )
    # Insertion order == construction order.
    # NOTE(review): only "orig" and "part_drop" pass remove_small_parts=False
    # explicitly; the other three fall back to the constructor default, which
    # is not visible here. Confirm whether that difference is intended.
    variant_flags = {
        "orig": dict(
            random_transform=False,
            force_retransform=False,
            remove_small_parts=False,
        ),
        "part_drop": dict(
            random_transform=False,
            force_retransform=True,
            random_part_drop=True,
            n_parts_to_drop=1,
            remove_small_parts=False,
        ),
        "rand_rot": dict(
            random_transform=True,
            force_retransform=True,
            random_rotation=True,
        ),
        "rand_no_rot": dict(
            random_transform=True,
            force_retransform=True,
            random_rotation=False,
        ),
        "all_aug": dict(
            random_transform=True,
            force_retransform=True,
            random_rotation=True,
            random_part_drop=True,
            n_parts_to_drop=1,
        ),
    }
    return {
        aug_id: CoMPaTSegmentDataset(root_dir, **shared, **flags)
        for aug_id, flags in variant_flags.items()
    }


def get_points(dataset, obj_k=0):
    """Draw one sample for object ``obj_k``.

    ``dataset[obj_k]`` is advanced once with ``next``. The yielded value must
    unpack into exactly three items.

    Returns:
        tuple ``(surface_points, occs, bbs)``.
    """
    sample_iter = dataset[obj_k]
    points, occupancies, boxes = next(sample_iter)
    return points, occupancies, boxes

def export_dataset_entry(datasets, obj_k, out_path):
    """Write sampled data for object ``obj_k`` from every dataset variant.

    For each ``(aug_id, dataset)`` pair, up to ``SAMPLES_PER_DATASET`` samples
    are drawn with ``get_points``. Each sample is written under ``out_path``
    using the prefix ``{model_id}_{aug_id}_{sample_id}``:

    * ``_points`` / ``_occs`` — the point and occupancy tensors, via ``.cpu().numpy()``
    * ``_transformation`` — ``dataset.transform_mat``
    * ``_bbs.pkl`` — the bounding boxes, via ``dump_pickle``

    ``np.save`` appends ``.npy`` to the first three names. The "orig" variant
    writes at most one sample.
    """
    for aug_id, dataset in datasets.items():
        sample_ids = range(SAMPLES_PER_DATASET)
        if aug_id == "orig":
            # The un-augmented variant is exported only once.
            sample_ids = sample_ids[:1]

        for sample_id in sample_ids:
            surface_points, occs, bbs = get_points(dataset, obj_k=obj_k)
            prefix = f"{out_path}/{dataset.get_model_id()}_{aug_id}_{sample_id}"

            np.save(f"{prefix}_points", surface_points.cpu().numpy())
            np.save(f"{prefix}_occs", occs.cpu().numpy())
            np.save(f"{prefix}_transformation", dataset.transform_mat)
            dump_pickle(bbs, f"{prefix}_bbs.pkl")

def _object_range(process_id, n_workers, num_objects):
    """Return the contiguous range of object indices owned by ``process_id``.

    Objects are split across ``min(n_workers, num_objects)`` workers; the
    first ``num_objects % workers`` of them get one extra object. Workers past
    that count (all workers when ``num_objects == 0``) get an empty range.
    """
    workers = min(n_workers, num_objects)
    if process_id >= workers:
        return range(0)
    base, extra = divmod(num_objects, workers)
    start = process_id * base + min(process_id, extra)
    stop = start + base + (1 if process_id < extra else 0)
    return range(start, stop)


def main(process_id, max_process):
    """Export every matched CoMPaT class, limited to this worker's objects.

    Args:
        process_id: zero-based index of this worker.
        max_process: total number of workers.

    Raises:
        ValueError: if ``max_process < 1`` or ``process_id`` is outside
            ``[0, max_process)``. Without this check, a negative id yields
            negative object indices, and an id that is too large silently
            exports nothing.
    """
    if max_process < 1:
        raise ValueError(f"max_process must be >= 1, got {max_process}")
    if not 0 <= process_id < max_process:
        raise ValueError(
            f"process_id must be in [0, {max_process}), got {process_id}"
        )

    for active_class in COMPAT_MATCHED_CLASSES:
        print(f"Processing class {active_class}")

        # The datasets are built even when this worker ends up skipping the
        # class, because the object count is read from them.
        all_datasets = get_datasets(active_class)
        num_objects = len(all_datasets["orig"])

        to_process = _object_range(process_id, max_process, num_objects)
        if not to_process:
            print(f"Process {process_id} skipping class {active_class}")
            continue

        print(f"Process {process_id} processing range: {to_process}")

        # Iterate over all objects in all datasets jointly
        for k in to_process:
            export_dataset_entry(all_datasets, k, OUT_PATH)
            print(f"Processed object {k + 1}/{num_objects}")


if __name__ == "__main__":
    # Each worker is launched with its own index plus the total worker count;
    # main() gives every worker a contiguous share of each class's objects.
    parser = argparse.ArgumentParser(
        description="Extract points from the CoMPaT dataset with part bounding boxes and occupancy grids"
    )
    parser.add_argument("process_id", type=int, help="ID of the current process")
    parser.add_argument("max_process", type=int, help="Total number of processes")
    args = parser.parse_args()
    main(process_id=args.process_id, max_process=args.max_process)

In [None]:
import k3d
import numpy as np

def plot_mesh_bbs(mesh, bbs):
    """Draw a mesh and its per-part bounding boxes in a k3d plot.

    Args:
        mesh: object exposing ``vertices``, ``faces`` and
            ``bounding_box_oriented`` (presumably a trimesh mesh — confirm).
        bbs: sequence of per-part entries. Item ``[1]`` of each entry is a
            mesh-like object with ``vertices`` and ``faces``.

    Returns:
        The assembled ``k3d`` plot.
    """
    # One Rainbow colour per box, converted to plain ints for k3d.
    part_ids = np.array(range(len(bbs)))
    colors = list(map(int, k3d.helpers.map_colors(
        part_ids, k3d.colormaps.basic_color_maps.Rainbow
    )))

    plot = k3d.plot()

    # Light-grey shape plus a faint oriented bounding box around all of it.
    plot += k3d.mesh(np.array(mesh.vertices), np.array(mesh.faces), color=0xefefef)
    obb = mesh.bounding_box_oriented
    plot += k3d.mesh(obb.vertices, obb.faces, color=0xefefef, opacity=0.1)

    # Semi-transparent coloured box for each part.
    for color, bb in zip(colors, bbs):
        box_mesh = bb[1]
        plot += k3d.mesh(box_mesh.vertices, box_mesh.faces, color=color, opacity=0.5)

    return plot

In [None]:
from util.misc import load_pickle

# Overlay the saved part bounding boxes of one exported sample on a mesh.
# Exported files are named "{model_id}_{aug_id}_{sample_id}_...", so
# "all_aug_15" selects sample 15 of the "all_aug" variant.
aug_type = "all_aug_15"
# NOTE(review): `all_datasets` is only a local inside main(), so this line
# relies on another cell having bound a module-level `all_datasets` with an
# "all_aug" key — confirm. The mesh is whatever that dataset loaded last and
# may not correspond to model "0c_000" / sample 15.
mesh = all_datasets["all_aug"].mesh.trimesh_mesh
bbs = load_pickle(f"{OUT_PATH}/0c_000_{aug_type}_bbs.pkl")

plot_mesh_bbs(mesh, bbs)

In [None]:
# Number of objects in the "all_aug" variant (needs a module-level `all_datasets` from another cell).
len(all_datasets["all_aug"])

In [3]:
import os

# Ids of models that already have files in the latents directory.
# Assumes a model id is the first 6 characters of each filename
# (e.g. "0c_000") — confirm for every file in the directory.
# A module-level `global` declaration is a no-op, so it is not needed here.
all_files = os.listdir("/ibex/project/c2273/3DCoMPaT/latents/")
PROCESSED_MODELS = {filename[:6] for filename in all_files}

len(PROCESSED_MODELS)

7999

In [2]:
"""
Count the objects available in each matched CoMPaT class (no points are extracted in this cell).
"""
import os
import sys
sys.path.append("/ibex/user/slimhy/PADS/code")
from datasets.shapeloaders import CoMPaTSegmentDataset
from datasets.metadata import COMPAT_MATCHED_CLASSES


def get_datasets(active_class):
    """Build only the un-augmented CoMPaT dataset for ``active_class``.

    The dataset samples 2048 surface points per shape and applies no random
    transform.

    Returns:
        ``{"orig": CoMPaTSegmentDataset}``.
    """
    return {
        "orig": CoMPaTSegmentDataset(
            "/ibex/project/c2273/3DCoMPaT/manifold_part_instances/",
            shape_cls=active_class,
            n_points=2048,
            sampling_method="surface",
            recenter_mesh=True,
            process_mesh=True,
            scale_to_shapenet=True,
            align_to_shapenet=True,
            random_transform=False,
            force_retransform=False,
            remove_small_parts=False,
        ),
    }

# Total number of objects over every matched class, accumulated in `full_count`.
# `all_datasets` / `num_objects` are left bound to the last class for later cells.
full_count = 0
for active_class in COMPAT_MATCHED_CLASSES:
    print(f"Processing class {active_class}")
    all_datasets = get_datasets(active_class)
    num_objects = len(all_datasets["orig"])
    full_count = full_count + num_objects

Processing class airplane
Processing class bag
Processing class basket
Processing class bed
Processing class bench
Processing class bird_house
Processing class boat
Processing class cabinet
Processing class car
Processing class chair
Processing class dishwasher
Processing class dresser
Processing class faucet
Processing class jug
Processing class lamp
Processing class love_seat
Processing class ottoman
Processing class planter
Processing class shelf
Processing class skateboard
Processing class sofa
Processing class sports_table
Processing class stool
Processing class table
Processing class trashcan
Processing class vase
