In [1]:
%cd /ibex/user/slimhy/PADS/code/
# %env CUDA_LAUNCH_BLOCKING=1
"""
Extracting features into HDF5 files for each split.
"""
import argparse
import torch

from datasets.shapeloaders import CoMPaTSegmentDataset, ShapeNetDataset
from datasets.metadata import (
    SHAPENET_NAME_TO_SYNSET_INDEX,
    SHAPENET_NAME_TO_SYNSET
)

import util.misc as misc
import util.s2vs as s2vs
import numpy as np
import trimesh
from util.misc import d_GPU, 

/ibex/user/slimhy/PADS/code


ImportError: cannot import name 'show_side_by_side' from 'util.misc' (/ibex/user/slimhy/PADS/code/util/misc.py)

In [None]:
def get_args_parser():
    parser = argparse.ArgumentParser("Extracting Features", add_help=False)

    # Model parameters
    parser.add_argument(
        "--batch_size",
        default=32,
        type=int,
        help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus",
    )
    parser.add_argument(
        "--text_model_name",
        type=str,
        help="Text model name to use",
    )
    parser.add_argument(
        "--ae",
        type=str,
        metavar="MODEL",
        help="Name of autoencoder",
    )
    parser.add_argument(
        "--ae-latent-dim",
        type=int,
        default=512*8,
        help="AE latent dimension",
    )
    parser.add_argument(
        "--ae_pth",
        required=True,
        help="Autoencoder checkpoint"
    )
    parser.add_argument(
        "--point_cloud_size",
        default=2048,
        type=int,
        help="input size"
    )
    parser.add_argument(
        "--fetch_keys",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--use_embeds",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--intensity_loss",
        action="store_true",
        default=False,
        help="Contrastive edit intensity loss using ground-truth labels.",
    )

    # Dataset parameters
    parser.add_argument(
        "--dataset",
        type=str,
        choices=["graphedits"],
        help="dataset name",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        help="dataset path",
    )
    parser.add_argument(
        "--data_type",
        type=str,
        help="dataset type",
    )
    parser.add_argument(
        "--max_edge_level",
        default=None,
        type=int,
        help="maximum edge level to use",
    )
    parser.add_argument(
        "--device", default="cuda", help="device to use for training / testing"
    )
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--num_workers", default=60, type=int)
    parser.add_argument(
        "--pin_mem",
        action="store_true",
        help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.",
    )

    return parser

# Set dummy arg string to debug the parser
call_string = """--ae_pth ckpt/ae_m512.pth \
    --ae kl_d512_m512_l8 \
    --ae-latent-dim 4096 \
    --batch_size 32 \
    --num_workers 8 \
    --device cuda"""
    

# Parse the arguments
args = get_args_parser()
args = args.parse_args(call_string.split())

# Set device and seed
device = torch.device(args.device)
misc.set_all_seeds(args.seed)
torch.backends.cudnn.benchmark = True

# Instantiate autoencoder
ae = s2vs.load_model(args.ae, args.ae_pth, device, torch_compile=True)
ae = ae.eval()

In [None]:
def get_rec_mesh(ae, points, device):
    points = points.to(device)
    init_latents = s2vs.encode_pc(ae, points)
    return s2vs.decode_latents(ae, d_GPU(init_latents), grid_density=256, batch_size=128**3)

def get_datasets(active_class):
    shapenet_dataset = ShapeNetDataset(
        dataset_folder="/ibex/project/c2273/ShapeNet/",
        shape_cls=SHAPENET_NAME_TO_SYNSET[active_class],
        pc_size=2048,
    )

    compat_dataset = CoMPaTSegmentDataset(
        "/ibex/project/c2273/3DCoMPaT/manifold_part_instances/",
        active_class,
        2048,
        sampling_method="surface",
        recenter_mesh=True,
        process_mesh=True,
        scale_to_shapenet=True,
        align_to_shapenet=True,
        random_transform=True,
        force_retransform=True,
    )

    compat_transformed_dataset = CoMPaTSegmentDataset(
        "/ibex/project/c2273/3DCoMPaT/manifold_part_instances/",
        active_class,
        2048,
        sampling_method="surface",
        recenter_mesh=True,
        process_mesh=True,
        scale_to_shapenet=True,
        align_to_shapenet=True,
        random_transform=False,
        force_retransform=True,
        random_part_drop=True,
        n_parts_to_drop=1,
        remove_small_parts=True,
        min_part_volume=0.005**3
    )
    
    return shapenet_dataset, compat_dataset, compat_transformed_dataset


def get_points(dataset, transform=None, is_compat=False, is_seg=False, obj_k=0):
    if is_compat:
        if is_seg:
            surface_points, _, bbs = next(dataset[obj_k])
            return surface_points, bbs
        else:
            surface_points, _ = next(dataset[obj_k])
        return surface_points
    else:
        surface_points, _ = dataset[obj_k]
        return transform(surface_points)


def center_mesh(mesh):
    """
    Center the mesh.
    """
    # Recenter the trimesh mesh
    center = mesh.bounding_box.centroid
    translation_mat = trimesh.transformations.translation_matrix(-center)

    # Apply the transformation
    return mesh.apply_transform(translation_mat)

# Load shape
def get_gen(class_name):
    class_id = SHAPENET_NAME_TO_SYNSET_INDEX[class_name]
    shape_path = "/ibex/user/slimhy/3DShape2VecSet/class_cond_obj/kl_d512_m512_l8_d24_edm/"
    return center_mesh(trimesh.load(shape_path + "%02d-00000.obj" % class_id))

In [None]:
MESH_ID = 5
ACTIVE_CLASS = "chair"
shapenet_dataset, compat_dataset, compat_transformed_dataset = get_datasets(ACTIVE_CLASS)

In [None]:
compat_points, bbs = get_points(compat_dataset, is_compat=True, is_seg=True, obj_k=MESH_ID)
compat_transformed_points, bbs = get_points(compat_transformed_dataset, is_compat=True, is_seg=True, obj_k=MESH_ID)

compat_mesh = get_rec_mesh(ae, compat_points, device)
compat_transformed_mesh = get_rec_mesh(ae, compat_transformed_points, device)

show_side_by_side(compat_mesh, compat_transformed_mesh)

In [None]:
import k3d

def plot_mesh_bbs(mesh, bbs):
    # Use k3d colormaps
    unique_parts = np.array(range(len(bbs)))
    col_map = k3d.helpers.map_colors(unique_parts, k3d.colormaps.basic_color_maps.Rainbow)
    col_map = [int(c) for c in col_map]

    # Create the plot
    plot = k3d.plot()

    plot += k3d.mesh(np.array(mesh.vertices.cpu().numpy()), np.array(mesh.faces.cpu().numpy()), color=0xefefef)
    plot += k3d.mesh(mesh.trimesh_mesh.bounding_box_oriented.vertices, mesh.trimesh_mesh.bounding_box_oriented.faces, color=0xefefef, opacity=0.1)
    for k, bb in enumerate(bbs):
        # if k != 0: continue
        bb_mesh = bb[1]
        # Set color with low alpha
        plot += k3d.mesh(bb_mesh.vertices, bb_mesh.faces, color=col_map[k], opacity=0.5)
        
    return plot

In [None]:
plot_mesh_bbs(compat_transformed_mesh, bbs)