In [1]:
"""
Extract points from the CoMPaT dataset with part bounding boxes and occupancy grids.
"""

import os
import sys

sys.path.append("/ibex/user/slimhy/PADS/code")

import argparse
import numpy as np
import torch
from datasets.shapeloaders import CoMPaTSegmentDataset
from util.misc import dump_pickle
from datasets.metadata import COMPAT_MATCHED_CLASSES


N_POINTS_PER_SHAPE = 2**18
OUT_PATH = "/ibex/project/c2273/3DCoMPaT/packaged"
SAMPLES_PER_DATASET = 8
PROCESSED_MODELS = set()
SAMPLING_PARAMS = {
    "n_points": N_POINTS_PER_SHAPE,
    "sampling_method": "surface+near_surface",
    "near_surface_noise": 0.01,
}


def get_datasets(active_class):
    compat_dataset = CoMPaTSegmentDataset(
        "/ibex/project/c2273/3DCoMPaT/manifold_part_instances/",
        shape_cls=active_class,
        recenter_mesh=True,
        process_mesh=True,
        scale_to_shapenet=True,
        align_to_shapenet=True,
        random_transform=False,
        force_retransform=False,
        remove_small_parts=False,
        **SAMPLING_PARAMS,
    )

    compat_part_drop_dataset = CoMPaTSegmentDataset(
        "/ibex/project/c2273/3DCoMPaT/manifold_part_instances/",
        shape_cls=active_class,
        recenter_mesh=True,
        process_mesh=True,
        scale_to_shapenet=True,
        align_to_shapenet=True,
        random_transform=False,
        force_retransform=True,
        random_part_drop=True,
        n_parts_to_drop=1,
        remove_small_parts=False,
        **SAMPLING_PARAMS,
    )

    compat_random_aug_rotation_dataset = CoMPaTSegmentDataset(
        "/ibex/project/c2273/3DCoMPaT/manifold_part_instances/",
        shape_cls=active_class,
        recenter_mesh=True,
        process_mesh=True,
        scale_to_shapenet=True,
        align_to_shapenet=True,
        random_transform=True,
        force_retransform=True,
        random_rotation=True,
        **SAMPLING_PARAMS,
    )

    compat_random_aug_no_rotation_dataset = CoMPaTSegmentDataset(
        "/ibex/project/c2273/3DCoMPaT/manifold_part_instances/",
        shape_cls=active_class,
        recenter_mesh=True,
        process_mesh=True,
        scale_to_shapenet=True,
        align_to_shapenet=True,
        random_transform=True,
        force_retransform=True,
        random_rotation=False,
        **SAMPLING_PARAMS,
    )

    compat_random_all_aug_dataset = CoMPaTSegmentDataset(
        "/ibex/project/c2273/3DCoMPaT/manifold_part_instances/",
        shape_cls=active_class,
        recenter_mesh=True,
        process_mesh=True,
        scale_to_shapenet=True,
        align_to_shapenet=True,
        random_transform=True,
        force_retransform=True,
        random_rotation=True,
        random_part_drop=True,
        n_parts_to_drop=1,
        **SAMPLING_PARAMS,
    )

    return {
        "orig": compat_dataset,
        "part_drop": compat_part_drop_dataset,
        "rand_rot": compat_random_aug_rotation_dataset,
        "rand_no_rot": compat_random_aug_no_rotation_dataset,
        "all_aug": compat_random_all_aug_dataset,
    }


def get_points(dataset, obj_k=0):
    surface_points, occs, bbs = next(dataset[obj_k])
    return surface_points, occs, bbs


def initialize_processed_models(out_path):
    """
    Initialize the set of processed model IDs.
    """
    global PROCESSED_MODELS
    all_files = os.listdir(out_path)
    PROCESSED_MODELS = set(filename[:6] for filename in all_files)


def is_model_processed(model_id):
    """
    Check if the given model_id exists in the set of processed models.
    """
    global PROCESSED_MODELS
    return model_id in PROCESSED_MODELS


def export_dataset_entry(datasets, obj_k, out_path):
    global PROCESSED_MODELS
    model_id = datasets["orig"].get_model_id(obj_k)

    # Check if the model has already been processed
    if is_model_processed(model_id):
        print(f"Skipping model {model_id} as it has already been processed")
        return

    for aug_id, dataset in datasets.items():
        for sample_id in range(SAMPLES_PER_DATASET):
            sys.stdout.flush()

            all_points, occs, bbs = get_points(dataset, obj_k=obj_k)
            surface_points, near_surface_points = (
                all_points[0],
                all_points[1],
            )
            assert (
                torch.sum(occs[: N_POINTS_PER_SHAPE // 2].flatten())
                == N_POINTS_PER_SHAPE // 2
            )

            occs = occs[N_POINTS_PER_SHAPE // 2 :]
            sample_code = f"{model_id}_{aug_id}_{sample_id}"
            
            # Store the points
            np.save(
                f"{out_path}/{sample_code}_surface_points", surface_points.cpu().numpy()
            )
            np.save(
                f"{out_path}/{sample_code}_near_surface_points",
                near_surface_points.cpu().numpy(),
            )

            # Store the occupancy grid
            np.save(f"{out_path}/{sample_code}_occs", occs.cpu().numpy())

            # Store the transformation matrix
            np.save(f"{out_path}/{sample_code}_transformation", dataset.transform_mat)

            # Store the bounding boxes
            dump_pickle(bbs, f"{out_path}/{sample_code}_bbs.pkl")

            if aug_id == "orig":
                break

    # Add the processed model to the set
    PROCESSED_MODELS.add(model_id)
    print(f"Processed model {model_id}")


def main(process_id, max_process):
    for active_class in COMPAT_MATCHED_CLASSES:
        print(f"Processing class {active_class}")

        # Get all datasets
        all_datasets = get_datasets(active_class)
        return all_datasets
        
        # Get the number of objects in the dataset
        num_objects = len(all_datasets["orig"])

        # Determine the number of processes to use for this class
        processes_for_class = min(max_process, num_objects)

        # If the current process_id is greater than or equal to the number of objects,
        # this process doesn't need to do anything for this class
        if process_id >= processes_for_class:
            print(f"Process {process_id} skipping class {active_class}")
            continue

        # Calculate the slice for this process
        base_slice_size = num_objects // processes_for_class
        remainder = num_objects % processes_for_class

        # Distribute the remainder among the first 'remainder' processes
        start_idx = process_id * base_slice_size + min(process_id, remainder)
        end_idx = start_idx + base_slice_size + (1 if process_id < remainder else 0)

        to_process = range(start_idx, end_idx)

        print(f"Process {process_id} processing range: {to_process}")

        # Iterate over all objects in all datasets jointly
        for k in to_process:
            export_dataset_entry(all_datasets, k, OUT_PATH)
            print(f"Processed object {k + 1}/{num_objects}")


if __name__ == "__main__":
    # Initialize the set of processed models
    initialize_processed_models(OUT_PATH)

    all_datasets = main(0, 1)

    print("Done.")


No CUDA runtime is found, using CUDA_HOME='/sw/rl9g/cuda/12.2/rl9_binary'


Processing class airplane
Done.


In [2]:
all_datasets['orig'].sampling_fn

functools.partial(<function combine_samplings at 0x154449901750>, sampling_fns=[<function sample_surface_simple at 0x154449901630>, functools.partial(<function sample_near_surface at 0x1544499016c0>, noise_std=0.01, contain_method='occnets')])

In [2]:
import os

global PROCESSED_MODELS
all_files = os.listdir("/ibex/project/c2273/3DCoMPaT/packaged/")
PROCESSED_MODELS = set(filename[:6] for filename in all_files)

len(PROCESSED_MODELS)

7968

In [7]:
len([f for f in all_files if "_near_surface_points" in f])

260254

In [4]:
shape_count = 260254
batch_size = 128
batch_process_time = 120
n_gpus = 8


total_time = shape_count * batch_process_time / batch_size / n_gpus / 3600

total_time

8.471809895833333