In [1]:
%cd /ibex/user/slimhy/PADS/code
from datasets.shapeloaders import CoMPaTSegmentDataset
import numpy as np
import torch
import sys
from util.misc import dump_pickle

In [2]:
# Total points sampled per shape (2**18 = 262,144); export_dataset_entry
# splits them into two halves of N_POINTS_PER_SHAPE // 2.
N_POINTS_PER_SHAPE = 1 << 18
# Directory that export_dataset_entry writes its files into.
OUT_PATH = "/tmp/some_stuff/"
# Samples exported per model for every variant except "orig" (exported once).
SAMPLES_PER_DATASET = 8
# Model ids already exported during this kernel session (in-memory only).
PROCESSED_MODELS = set()


def get_datasets(
    active_class,
    root_dir="/ibex/project/c2273/3DCoMPaT/manifold_part_instances/",
):
    """Build the original and part-drop CoMPaT segment datasets for one class.

    Both variants are constructed from the same ``root_dir`` with identical
    mesh-processing flags and point-sampling settings. The ``"part_drop"``
    variant additionally passes ``force_retransform=True``,
    ``random_part_drop=True`` and ``n_parts_to_drop=1``; the ``"orig"``
    variant passes ``force_retransform=False``.

    Args:
        active_class: Shape class forwarded as ``shape_cls`` (e.g. ``"chair"``).
        root_dir: Directory of manifold part instances, passed as the first
            positional argument of ``CoMPaTSegmentDataset``. Defaults to the
            path that was previously hard-coded here.

    Returns:
        dict: ``{"orig": dataset, "part_drop": dataset}``.
    """
    # Arguments shared by both variants, kept in one place so the two
    # datasets cannot drift apart when only one copy gets edited.
    shared_kwargs = dict(
        shape_cls=active_class,
        recenter_mesh=True,
        process_mesh=True,
        scale_to_shapenet=True,
        align_to_shapenet=True,
        random_transform=False,
        remove_small_parts=False,
        n_points=N_POINTS_PER_SHAPE,
        sampling_method="surface+near_surface",
        near_surface_noise=0.01,
    )

    compat_dataset = CoMPaTSegmentDataset(
        root_dir,
        force_retransform=False,
        **shared_kwargs,
    )

    compat_part_drop_dataset = CoMPaTSegmentDataset(
        root_dir,
        force_retransform=True,
        random_part_drop=True,
        n_parts_to_drop=1,
        **shared_kwargs,
    )

    return {
        "orig": compat_dataset,
        "part_drop": compat_part_drop_dataset,
    }



def is_model_processed(model_id):
    """
    Return True when ``model_id`` is already recorded in PROCESSED_MODELS.
    """
    # Read-only lookup of the module-level set: no ``global`` statement needed.
    already_done = model_id in PROCESSED_MODELS
    return already_done

In [3]:
def export_dataset_entry(datasets, obj_k, out_path):
    """Sample one shape from every dataset variant and write the samples to disk.

    For each ``(dset_name, dataset)`` in ``datasets`` the shape at index
    ``obj_k`` is sampled, and every sample is written to ``out_path`` as::

        {model_id}_{dset_name}_{sample_id}_surface_points.npy
        {model_id}_{dset_name}_{sample_id}_near_surface_points.npy
        {model_id}_{dset_name}_{sample_id}_occs.npy
        {model_id}_{dset_name}_{sample_id}_transformation.npy
        {model_id}_{dset_name}_{sample_id}_bbs.pkl

    The ``"orig"`` variant is exported once; every other variant is exported
    ``SAMPLES_PER_DATASET`` times. A model already in ``PROCESSED_MODELS`` is
    skipped; otherwise its id is added there after all variants are written.

    Args:
        datasets: Mapping of variant name to dataset, as returned by
            ``get_datasets``. Must contain ``"orig"``, which supplies the model id.
        obj_k: Index of the shape within each dataset.
        out_path: Output directory; created (with parents) if missing.

    Raises:
        AssertionError: If the first ``N_POINTS_PER_SHAPE // 2`` occupancies
            of a sample do not sum to ``N_POINTS_PER_SHAPE // 2``.
    """
    import os

    model_id = datasets["orig"].get_model_id(obj_k)

    # Check if the model has already been processed
    if is_model_processed(model_id):
        print(f"Skipping model {model_id} as it has already been processed")
        return

    # np.save does not create directories, so make sure the target exists.
    os.makedirs(out_path, exist_ok=True)
    n_half = N_POINTS_PER_SHAPE // 2

    for dset_name, dataset in datasets.items():
        # "orig" is exported once; other variants SAMPLES_PER_DATASET times.
        if dset_name == "orig":
            n_samples = min(1, SAMPLES_PER_DATASET)
        else:
            n_samples = SAMPLES_PER_DATASET

        for sample_id in range(n_samples):
            sys.stdout.flush()

            all_points, occs, bbs = next(dataset[obj_k])
            surface_points, near_surface_points = all_points[0], all_points[1]

            # Presumably the first half of occs belongs to the surface samples
            # (inferred from the N//2 split): all of it must be occupied, and
            # only the remaining (near-surface) half is stored.
            n_occupied = torch.sum(occs[:n_half].flatten())
            assert n_occupied == n_half, (
                f"{model_id}/{dset_name}/{sample_id}: expected {n_half} occupied "
                f"samples in the first half, got {n_occupied.item()}"
            )
            occs = occs[n_half:]

            prefix = os.path.join(out_path, f"{model_id}_{dset_name}_{sample_id}")

            # Store the points
            np.save(f"{prefix}_surface_points", surface_points.cpu().numpy())
            np.save(f"{prefix}_near_surface_points", near_surface_points.cpu().numpy())

            # Store the per-point occupancies
            np.save(f"{prefix}_occs", occs.cpu().numpy())

            # Read after sampling; presumably the dataset updates transform_mat
            # per sample (TODO confirm in CoMPaTSegmentDataset).
            np.save(f"{prefix}_transformation", dataset.transform_mat)

            # Store the bounding boxes
            dump_pickle(bbs, f"{prefix}_bbs.pkl")

    # Add the processed model to the set
    PROCESSED_MODELS.add(model_id)
    print(f"Processed model {model_id}")


In [4]:
# Build the "orig" and "part_drop" dataset variants for the chair class.
dsets = get_datasets(active_class="chair")

In [7]:
from datasets.metadata import COMPAT_MATCHED_CLASSES  # NOTE(review): unused in this cell

# Export the shape at index 0 of every variant in `dsets` into OUT_PATH.
export_dataset_entry(datasets=dsets, obj_k=0, out_path=OUT_PATH)

In [None]:
def generate_colormap_colors(n, colormap_name="viridis", alpha=1.0):
    """Return ``n`` RGBA colors sampled evenly along a matplotlib colormap.

    Args:
        n: Number of colors. ``0`` yields an empty list; ``1`` yields the
            colormap's start color.
        colormap_name: Colormap name passed to ``cm.get_cmap``.
        alpha: Alpha value given to every color (replaces the colormap's own).

    Returns:
        list of ``(r, g, b, alpha)`` tuples.
    """
    # NOTE(review): ``matplotlib.cm.get_cmap`` was deprecated in matplotlib 3.7
    # and removed in 3.9; ``matplotlib.colormaps[colormap_name]`` replaces it.
    colormap = cm.get_cmap(colormap_name)
    # i / (n - 1) would divide by zero for a single color (n == 1).
    denom = max(n - 1, 1)
    return [(*colormap(i / denom)[:3], alpha) for i in range(n)]

def visualize_bounding_boxes(
    bounding_boxes,
    mesh=None,
    box_type="spheres",
    line_radius=0.005,
    colormap="viridis",
    alpha=1.0,
):
    """Build a trimesh scene that draws each bounding box in its own color.

    Each box's edges (``bb.get_edges()``) are always drawn as cylinders. When
    ``box_type == "spheres"`` its corners (``bb.get_corners()``) are also
    marked with spheres; any other ``box_type`` draws edges only.

    Args:
        bounding_boxes: Sized iterable of box objects exposing
            ``get_corners()`` and ``get_edges()``.
        mesh: Optional geometry added to the scene first.
        box_type: ``"spheres"`` to add corner spheres; anything else skips them.
        line_radius: Radius of both the edge cylinders and the corner spheres.
        colormap: Colormap name used by ``generate_colormap_colors``.
        alpha: Alpha value applied to every box color.

    Returns:
        trimesh.Scene containing the mesh (if given) and all box primitives.
    """
    scene = trimesh.Scene()
    if mesh is not None:
        scene.add_geometry(mesh)

    colors = generate_colormap_colors(len(bounding_boxes), colormap, alpha)
    draw_corners = box_type == "spheres"

    for bb, color in zip(bounding_boxes, colors):
        # One color per box: convert to the 0-255 scale once, not per primitive.
        face_color = np.array(color) * 255

        if draw_corners:
            for corner in bb.get_corners():
                sphere = trimesh.creation.uv_sphere(radius=line_radius)
                sphere.apply_translation(corner)
                sphere.visual.face_colors = face_color
                scene.add_geometry(sphere)

        for edge in bb.get_edges():
            cylinder = trimesh.creation.cylinder(
                radius=line_radius,
                segment=edge,
            )
            cylinder.visual.face_colors = face_color
            scene.add_geometry(cylinder)

    return scene

def _select_occupied_points(points, occs, threshold=0.5):
    """Return the rows of ``points`` whose occupancy in ``occs`` exceeds ``threshold``."""
    mask = np.asarray(occs).reshape(-1) > threshold
    points = np.asarray(points)
    flat_points = points.reshape(-1, points.shape[-1])
    if flat_points.shape[0] != mask.shape[0]:
        raise ValueError(
            f"{flat_points.shape[0]} points but {mask.shape[0]} occupancy values"
        )
    return flat_points[mask]


def visualize_dataset(out_path):
    """Show every sample that ``export_dataset_entry`` wrote to ``out_path``.

    Three scenes are shown per sample, in order: the surface point cloud,
    the near-surface points whose stored occupancy is > 0.5, and the
    sample's bounding boxes.

    Args:
        out_path: Directory holding ``*_surface_points.npy``,
            ``*_near_surface_points.npy``, ``*_occs.npy`` and ``*_bbs.pkl``.

    Raises:
        ValueError: If a sample's occupancy count differs from its
            near-surface point count.
    """
    import glob
    import os
    import pickle

    surface_suffix = "_surface_points.npy"
    near_suffix = "_near_surface_points.npy"

    # "*_surface_points.npy" also matches "*_near_surface_points.npy"; those
    # must be skipped or the derived base name ends in "_near" and the
    # occs/bbs lookups fail. Sorting makes the display order reproducible.
    surface_files = sorted(
        path
        for path in glob.glob(os.path.join(out_path, "*" + surface_suffix))
        if not path.endswith(near_suffix)
    )

    for sf in surface_files:
        base_name = sf[: -len(surface_suffix)]

        # Load data
        surface_points = np.load(sf)
        near_surface_points = np.load(f"{base_name}{near_suffix}")
        occs = np.load(f"{base_name}_occs.npy")
        # pickle.load can execute arbitrary code: only use on trusted,
        # locally exported files.
        with open(f"{base_name}_bbs.pkl", "rb") as f:
            bbs = pickle.load(f)

        # Visualize surface points
        visualize_pointcloud([surface_points]).show()

        # occs is a per-point vector (export stores it next to the
        # near-surface points), not a voxel grid: plot the occupied points
        # themselves rather than the indices np.where would return.
        occupied_points = _select_occupied_points(near_surface_points, occs)
        visualize_pointcloud([occupied_points]).show()

        # Visualize bounding boxes
        visualize_bounding_boxes(bbs).show()
