In [None]:
%cd /ibex/user/slimhy/PADS/code
"""
Extracting features into HDF5 files for each split.
"""
import argparse
import shlex

import numpy as np
import torch

import util.misc as misc
import models.s2vs as ae_mods


def get_args_parser():
    """Build the command-line parser used to configure this notebook.

    The parser is created with ``add_help=False`` (so it has no ``-h/--help``)
    and only ``--ae_pth`` is required. Several options (dataset, text-model and
    loss flags, among others) are not read by any cell visible in this notebook;
    they are kept so existing command lines continue to parse.

    Returns:
        argparse.ArgumentParser: the configured, not-yet-parsed parser.
    """
    parser = argparse.ArgumentParser("Extracting Features", add_help=False)

    # Model parameters
    parser.add_argument(
        "--batch_size",
        default=32,
        type=int,
        help="Batch size per GPU (effective batch size is batch_size * # gpus)",
    )
    parser.add_argument(
        "--text_model_name",
        type=str,
        help="Text model name to use",
    )
    parser.add_argument(
        "--ae",
        type=str,
        metavar="MODEL",
        help="Name of autoencoder",
    )
    parser.add_argument(
        "--ae-latent-dim",  # dest becomes `ae_latent_dim` (argparse maps '-' to '_')
        type=int,
        default=512*8,
        help="AE latent dimension",
    )
    parser.add_argument(
        "--ae_pth",
        type=str,
        required=True,
        help="Autoencoder checkpoint"
    )
    parser.add_argument(
        "--point_cloud_size",
        default=2048,
        type=int,
        help="input size"
    )
    parser.add_argument(
        "--fetch_keys",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--use_embeds",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--intensity_loss",
        action="store_true",
        default=False,
        help="Contrastive edit intensity loss using ground-truth labels.",
    )

    # Dataset parameters
    parser.add_argument(
        "--dataset",
        type=str,
        choices=["graphedits"],
        help="dataset name",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        help="dataset path",
    )
    parser.add_argument(
        "--data_type",
        type=str,
        help="dataset type",
    )
    parser.add_argument(
        "--max_edge_level",
        default=None,
        type=int,
        help="maximum edge level to use",
    )
    parser.add_argument(
        "--device", default="cuda", help="device to use for training / testing"
    )
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--num_workers", default=60, type=int)
    parser.add_argument(
        "--pin_mem",
        action="store_true",
        help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.",
    )

    return parser


# Dummy CLI string so the argument parser can be exercised from the notebook.
# The backslash-newlines are line continuations inside the (non-raw) string
# literal, so this is one line of space-separated tokens.
call_string = """--ae_pth ckpt/ae_m512.pth \
    --ae kl_d512_m512_l8 \
    --ae-latent-dim 4096 \
    --batch_size 32 \
    --num_workers 8 \
    --device cuda"""

# Parse the arguments. The parser and the parsed namespace get separate names;
# shlex.split honours shell quoting (e.g. paths with spaces), str.split does not.
parser = get_args_parser()
args = parser.parse_args(shlex.split(call_string))

# --------------------
device = torch.device(args.device)

# Fix the seed for reproducibility
misc.set_all_seeds(args.seed)

# Let cuDNN benchmark convolution algorithms and reuse the fastest one.
torch.backends.cudnn.benchmark = True
# --------------------

# Instantiate autoencoder: `--ae` names a factory defined in models.s2vs.
# Fail with a readable message instead of a bare KeyError when it is missing/unknown.
ae_factory = ae_mods.__dict__.get(args.ae) if args.ae else None
if not callable(ae_factory):
    raise ValueError(f"Unknown autoencoder {args.ae!r}: no such factory in models.s2vs")
ae = ae_factory()
ae.eval()
print("Loading autoencoder %s" % args.ae_pth)
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint = torch.load(args.ae_pth, map_location="cpu")
ae.load_state_dict(checkpoint["model"])
ae = ae.to(device)

# Compile using torch.compile; "max-autotune" spends extra compile time on the
# first call searching for faster kernels.
ae = torch.compile(ae, mode="max-autotune")

In [2]:
def obb_to_corners(box_array):
    """Return the 8 corner points of one or more oriented bounding boxes.

    Args:
        box_array: Array-like whose last two axes hold at least 4 rows: row 0
            is the box center, rows 1-3 are the ``right``, ``up`` and
            ``forward`` axis vectors (extra rows are ignored). A single box is
            typically shape (4, 3); a stack of boxes may be (..., 4, 3).
            NOTE(review): corners are placed at ``center ± right ± up ± forward``,
            so the axis vectors are assumed to be half-extent-scaled rather than
            unit directions — confirm against the dataset.

    Returns:
        np.ndarray: Corners of shape (..., 8, 3), i.e. (8, 3) for one box. The
        first four corners take coefficient -1 on ``forward`` and the last four
        +1; within each group the (right, up) signs run (-,-), (+,-), (+,+), (-,+).

    Raises:
        ValueError: if the input does not have at least 2 dims with >= 4 rows.
    """
    boxes = np.asarray(box_array)
    if boxes.ndim < 2 or boxes.shape[-2] < 4:
        raise ValueError(
            f"expected box_array of shape (..., 4, 3), got {boxes.shape}"
        )

    # Sign pattern of each corner along (right, up, forward)
    signs = np.array([
        [-1, -1, -1],
        [ 1, -1, -1],
        [ 1,  1, -1],
        [-1,  1, -1],
        [-1, -1,  1],
        [ 1, -1,  1],
        [ 1,  1,  1],
        [-1,  1,  1],
    ])

    center = boxes[..., 0, :]
    axes = boxes[..., 1:4, :]  # rows: right, up, forward

    # corner_i = s_i0 * right + s_i1 * up + s_i2 * forward, then shift by center;
    # matmul broadcasts over any leading (batch) dimensions.
    return signs @ axes + center[..., np.newaxis, :]

In [3]:
from datasets.latents import ShapeLatentDataset, ComposedPairedShapesLoader
from datasets.metadata import class_to_hex

latents_dir = "/ibex/project/c2273/PADS/3DCoMPaT"

# Latent dataset for the "chair" class
dataset = ShapeLatentDataset(latents_dir, shuffle_parts=False, class_code=class_to_hex("chair"))

# Paired-shapes loader over the dataset
dataloader = ComposedPairedShapesLoader(
    dataset,
    batch_size=32,
    pair_types_list=['rand_no_rot,rand_no_rot', 'part_drop,orig'],
    num_workers=0,
    shuffle=False,
)

# Pick the k_break-th batch (1-based) for inspection. Fail loudly instead of
# silently keeping the last batch (or nothing) when the loader is too short.
k_break = 1
if k_break < 1:
    raise ValueError(f"k_break must be >= 1, got {k_break}")
for batch_idx, batch in enumerate(dataloader, start=1):
    if batch_idx == k_break:
        break
else:
    raise RuntimeError(f"dataloader yielded fewer than {k_break} batches")

pair_types, (latents, part_bbs, part_labels, _, meta_A), (
    latent_B,
    bb_coords_B,
    bb_labels_B,
    _,
    meta_B,
) = batch

In [4]:
# First sample of the A-side batch: its latent and part boxes, on the compute device.
latent_0 = latents[0].to(device)
part_bbs_0 = part_bbs[0].to(device)

In [5]:
from util import s2vs

# Decode a perturbed copy of latent_0 to a mesh using s2vs.
# latent_0 is permuted along dim 0 (row order). NOTE(review): if s2vs latents
# are an unordered set of tokens, this presumably probes the decoder's
# permutation invariance — confirm the intent.
# A new name is used instead of overwriting latent_0, so re-running this cell
# always starts from the latent produced by the previous cell.
N_KEEP = 8  # leading columns (dim 1) copied from the real latent; the rest stay N(0, 1) noise
latent_perm = latent_0[torch.randperm(latent_0.size(0))]  # advanced indexing already copies
assert latent_perm.dim() >= 2 and latent_perm.size(1) >= N_KEEP, (
    f"latent has shape {tuple(latent_perm.shape)}; need at least {N_KEEP} columns"
)
new_latents = torch.randn_like(latent_perm)
# NOTE(review): if dim 1 has exactly N_KEEP entries, every column is copied and
# no noise remains — TODO confirm this is intended.
new_latents[:, :N_KEEP] = latent_perm[:, :N_KEEP]
mesh = s2vs.decode_latents(ae, new_latents.unsqueeze(0), grid_density=256, batch_size=64**3)

In [6]:
import numpy as np
import trimesh

def create_trimesh_boxes(bounding_boxes, radius=0.01):
    """Build one corner-marker mesh per oriented bounding box.

    Each box is drawn as eight small spheres, one at each corner returned by
    ``obb_to_corners``, and the spheres of a box are concatenated into a
    single mesh.

    Args:
        bounding_boxes: Iterable of boxes in the layout ``obb_to_corners``
            expects (center, right, up, forward). Boxes that are entirely zero
            are skipped (presumably padding for absent parts — TODO confirm).
        radius: Radius of each corner sphere. Defaults to 0.01.

    Returns:
        list: One concatenated trimesh mesh per box that is not all zeros, in
        input order. The list is shorter than ``bounding_boxes`` when boxes are
        skipped, so list index != box index in that case.
    """
    meshes = []
    for box in bounding_boxes:
        # All-zero boxes carry no geometry; skip before building anything.
        if np.all(box == 0):
            continue
        corner_spheres = [
            trimesh.primitives.Sphere(radius=radius).apply_translation(corner)
            for corner in obb_to_corners(box)
        ]
        meshes.append(trimesh.util.concatenate(corner_spheres))

    return meshes

# Host-side NumPy copy of the part boxes (NumPy cannot read CUDA tensors directly)
bounding_boxes = part_bbs_0.cpu().numpy().copy()

# Sphere-marker meshes, one per box that is not all zeros
trimesh_boxes = create_trimesh_boxes(bounding_boxes)


In [None]:
import k3d

# One colour per box mesh, sampled evenly from k3d's Rainbow colormap.
# An explicit color_range stops map_colors from deriving (min, max) from the
# data, which fails for zero boxes and is degenerate for a single box; for two
# or more boxes it equals the data range (0, n - 1).
n_boxes = len(trimesh_boxes)
unique_parts = np.arange(n_boxes)
col_map = k3d.helpers.map_colors(
    unique_parts,
    k3d.colormaps.basic_color_maps.Rainbow,
    color_range=(0, max(n_boxes - 1, 1)),
)
col_map = [int(c) for c in col_map]

# Create the plot: the decoded shape as a translucent light-grey mesh ...
plot = k3d.plot()
plot += k3d.mesh(
    mesh.vertices.cpu().numpy(),
    mesh.faces.cpu().numpy(),
    color=0xefefef,
    opacity=0.5,
)
# ... then each box's corner markers in its own colour, also half-transparent.
for bb_mesh, bb_color in zip(trimesh_boxes, col_map):
    plot += k3d.mesh(bb_mesh.vertices, bb_mesh.faces, color=bb_color, opacity=0.5)

plot