In [None]:
%cd /ibex/user/slimhy/PADS/code/
"""
Sanity-check notebook: reconstruct one ShapeNet shape and one CoMPaT shape of
ACTIVE_CLASS through the s2vs autoencoder, display them side by side with a
class-conditional generated mesh, and print their axis-aligned bounding boxes.
"""
import argparse
import os

import torch
import trimesh

import util.misc as misc
import util.s2vs as s2vs

from datasets.shapeloaders import CoMPaTManifoldDataset, ShapeNetDataset
from datasets.metadata import SHAPENET_NAME_TO_SYNSET, SHAPENET_NAME_TO_SYNSET_INDEX
from util.misc import d_GPU, show_side_by_side

In [None]:
def get_args_parser():
    """
    Build the command-line parser for this notebook / feature-extraction script.

    The parser is created with ``add_help=False`` so it can be embedded as a
    parent parser. Only ``--ae_pth`` is required.

    Returns:
        argparse.ArgumentParser: the configured (unparsed) parser.
    """
    parser = argparse.ArgumentParser("Extracting Features", add_help=False)

    # Model parameters
    parser.add_argument(
        "--batch_size",
        default=32,
        type=int,
        help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)",
    )
    parser.add_argument(
        "--text_model_name",
        type=str,
        help="Text model name to use",
    )
    parser.add_argument(
        "--ae",
        type=str,
        metavar="MODEL",
        help="Name of autoencoder",
    )
    # NOTE: hyphenated flag; argparse exposes it as ``args.ae_latent_dim``.
    parser.add_argument(
        "--ae-latent-dim",
        type=int,
        default=512*8,
        help="AE latent dimension",
    )
    parser.add_argument(
        "--ae_pth",
        type=str,
        required=True,
        help="Autoencoder checkpoint"
    )
    parser.add_argument(
        "--point_cloud_size",
        default=2048,
        type=int,
        help="input size"
    )
    parser.add_argument(
        "--fetch_keys",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--use_embeds",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--intensity_loss",
        action="store_true",
        default=False,
        help="Contrastive edit intensity loss using ground-truth labels.",
    )

    # Dataset parameters
    parser.add_argument(
        "--dataset",
        type=str,
        choices=["graphedits"],
        help="dataset name",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        help="dataset path",
    )
    parser.add_argument(
        "--data_type",
        type=str,
        help="dataset type",
    )
    parser.add_argument(
        "--max_edge_level",
        default=None,
        type=int,
        help="maximum edge level to use",
    )
    parser.add_argument(
        "--device", default="cuda", help="device to use for training / testing"
    )
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--num_workers", default=60, type=int)
    parser.add_argument(
        "--pin_mem",
        action="store_true",
        help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.",
    )

    return parser

In [None]:
# Debug-time CLI: parse this fixed argument string instead of sys.argv.
# (The backslash-newlines are line continuations inside the literal, so the
# tokens end up separated only by spaces.)
call_string = """--ae_pth ckpt/ae_m512.pth \
    --ae kl_d512_m512_l8 \
    --ae-latent-dim 4096 \
    --batch_size 32 \
    --num_workers 8 \
    --device cuda"""

# Build the parser and parse the whitespace-separated tokens in one step.
args = get_args_parser().parse_args(call_string.split())

In [None]:
# Target device, RNG seeding, and cuDNN autotuning.
device = torch.device(args.device)
misc.set_all_seeds(args.seed)
torch.backends.cudnn.benchmark = True

# Load the autoencoder checkpoint (torch_compile=True) and switch it to eval mode;
# Module.eval() returns the module itself, so the call can be chained.
ae = s2vs.load_model(args.ae, args.ae_pth, device, torch_compile=True).eval()

In [None]:
ACTIVE_CLASS = "lamp"  # ShapeNet/CoMPaT class name used throughout the notebook
N_POINTS = 1 << 21  # number of points sampled from the generated mesh for bbox stats
MAX_POINTS = 1 << 21  # Maximum number of points for a single batch on an A100 GPU

In [None]:
import numpy as np
import torch
from datasets.metadata import (
    get_compat_transform,
    get_shapenet_transform,
)

In [None]:
OBJ_DIR, OBJ_ID = "/ibex/project/c2273/ShapeNet/", 7

# Load ShapeNet object OBJ_ID of ACTIVE_CLASS with as many surface points as the
# AE expects (ae.num_inputs), then encode it into latents detached from autograd.
orig_dataset = ShapeNetDataset(
    dataset_folder=OBJ_DIR,
    shape_cls=SHAPENET_NAME_TO_SYNSET[ACTIVE_CLASS],
    pc_size=ae.num_inputs,
    replica=1,
)
surface_points, _ = orig_dataset[OBJ_ID]
surface_points = surface_points.to(device)
init_latents = s2vs.encode_pc(ae, surface_points).detach()

In [None]:
# Root directory passed to CoMPaTManifoldDataset below.
# NOTE: this rebinds OBJ_DIR, which the ShapeNet cell set to the ShapeNet root.
OBJ_DIR = "/ibex/user/slimhy/PADS/data/obj_manifold/"

def flip_front_to_right(pc):
    """
    Rotate a point cloud by +90 degrees about the Y axis (front-facing -> right-facing).

    Each point (x, y, z) is mapped to (z, y, -x).

    Args:
        pc: tensor with xyz coordinates on the last axis, either unbatched
            ``(N, 3)`` or batched ``(B, N, 3)``.

    Returns:
        The rotated points, always batched: ``(1, N, 3)`` for unbatched input,
        ``(B, N, 3)`` for batched input. Same dtype and device as ``pc``.
    """
    rotation = torch.tensor(
        [[0.0, 0.0, -1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]],
        dtype=pc.dtype,
        device=pc.device,
    )
    # Add the batch dim explicitly rather than squeeze()/unsqueeze(0): squeeze()
    # also removes size-1 point dims (N == 1) and mis-shapes batches with B > 1.
    if pc.dim() == 2:
        pc = pc.unsqueeze(0)
    # Row-vector convention: (..., 3) @ (3, 3) broadcasts over the leading dims.
    return torch.matmul(pc, rotation)

# Load object OBJ_ID of ACTIVE_CLASS from the CoMPaT manifold data
# (normalize=False, scale_to_shapenet=False) and apply the class's CoMPaT transform.
alt_dataset = CoMPaTManifoldDataset(
    OBJ_DIR, ACTIVE_CLASS, ae.num_inputs,
    normalize=False, sampling_method="surface", scale_to_shapenet=False,
)
# The indexed item is consumed with next(), so indexing presumably returns an
# iterator whose first element is a (points, _) pair -- TODO confirm.
surface_points_alt, _ = next(alt_dataset[OBJ_ID])
surface_points_alt = get_compat_transform(ACTIVE_CLASS)(surface_points_alt)

In [None]:
def center_mesh(mesh):
    """
    Translate ``mesh`` in place so that its centroid lies at the origin.

    Returns the same (mutated) mesh object, for chaining.
    """
    offset = mesh.centroid
    mesh.vertices = mesh.vertices - offset
    return mesh
    
# Directory of class-conditional generated meshes, named
# "<class index:02d>-<5-digit index>.obj" (the 5-digit index is presumably the sample id).
GEN_SHAPE_DIR = "/ibex/user/slimhy/3DShape2VecSet/class_cond_obj/kl_d512_m512_l8_d24_edm/"


def get_gen(class_name, shape_dir=GEN_SHAPE_DIR, sample_idx=0):
    """
    Load a generated mesh for ``class_name`` and center it at the origin.

    Args:
        class_name: class name; must be a key of SHAPENET_NAME_TO_SYNSET_INDEX.
        shape_dir: directory containing the generated ``.obj`` files.
        sample_idx: value of the 5-digit filename suffix (0 -> "00000").

    Returns:
        The object returned by ``trimesh.load`` (expected to be a Trimesh),
        translated in place so its centroid is at the origin.

    Raises:
        KeyError: if ``class_name`` is not in SHAPENET_NAME_TO_SYNSET_INDEX.
    """
    class_id = SHAPENET_NAME_TO_SYNSET_INDEX[class_name]
    shape_path = os.path.join(shape_dir, "%02d-%05d.obj" % (class_id, sample_idx))
    return center_mesh(trimesh.load(shape_path))

In [None]:
# Reconstruct the ShapeNet sample through the autoencoder (points moved to the
# device before encoding; latents detached, matching the loading cell).
surface_points = surface_points.to(device)
init_latents = s2vs.encode_pc(ae, surface_points).detach()
rec_mesh = s2vs.decode_latents(ae, d_GPU(init_latents), grid_density=256, batch_size=128**3)

# Reconstruct the CoMPaT sample after mapping it into the ShapeNet frame.
# NOTE: surface_points_alt is rebound here, so re-running this cell applies the
# ShapeNet transform a second time -- re-run the CoMPaT loading cell first.
# init_latents ends up holding the CoMPaT latents when this branch runs.
if surface_points_alt is not None:
    surface_points_alt = get_shapenet_transform(ACTIVE_CLASS)(surface_points_alt)
    init_latents = s2vs.encode_pc(ae, surface_points_alt).detach()
    rec_mesh_compat = s2vs.decode_latents(ae, d_GPU(init_latents), grid_density=256, batch_size=128**3)
else:
    # No CoMPaT sample: show an empty mesh in its slot.
    rec_mesh_compat = trimesh.Trimesh(vertices=[], faces=[])

# Display: ShapeNet reconstruction, generated mesh, CoMPaT reconstruction.
shapenet_gen = get_gen(ACTIVE_CLASS)
show_side_by_side(rec_mesh, shapenet_gen, rec_mesh_compat, flip_front_to_back=False)

In [None]:
def get_stuff(surface_points):
    """
    Print the per-axis min/max and the extents of a batched point cloud.

    ``surface_points`` is indexed as ``[:, :, axis]`` for axis 0..2, so it must be
    3-D with xyz on the last axis (torch tensors and numpy arrays both work).
    For each axis one line "min max" is printed, followed by the list of
    extents ``[max_x - min_x, max_y - min_y, max_z - min_z]``. Returns None.
    """
    extents = []
    for axis in range(3):
        coords = surface_points[:, :, axis]
        lo = coords.min()
        hi = coords.max()
        print(lo.item(), hi.item())
        extents.append((hi - lo).item())
    print(extents)


# Compare bounding boxes: the CoMPaT sample (after the preceding cells' transforms),
# the ShapeNet sample, then N_POINTS surface samples of the generated mesh,
# given a leading batch axis so get_stuff can index it as [:, :, axis].
get_stuff(surface_points_alt)
print()
get_stuff(surface_points)
print()
sampled_points = np.array(shapenet_gen.sample(N_POINTS))[None]
get_stuff(sampled_points)

In [None]:
# %cd /ibex/user/slimhy/3DShape2VecSet/
# ! python sample_class_cond.py \
#     --ae kl_d512_m512_l8 \
#     --ae-pth output/ae/kl_d512_m512_l8/checkpoint-199.pth \
#     --dm kl_d512_m512_l8_d24_edm \
#     --dm-pth output/class_cond_dm/kl_d512_m512_l8_d24_edm/checkpoint-499.pth
# 