In [1]:
%cd /home/slimhy/Documents/PADS/code
"""
Extracting features into HDF5 files for each split.
"""
import argparse
import torch

import util.misc as misc
import models.s2vs as ae_mods


def get_args_parser():
    """Build the command-line parser used by this script.

    The parser is created with ``add_help=False``, so no ``-h/--help`` option
    is registered on it.

    Returns:
        argparse.ArgumentParser: parser exposing the model options
        (``--batch_size``, ``--text_model_name``, ``--ae``,
        ``--ae-latent-dim``/``--ae_latent_dim``, ``--ae_pth``,
        ``--point_cloud_size``, ``--fetch_keys``, ``--use_embeds``,
        ``--intensity_loss``) and the dataset/runtime options (``--dataset``,
        ``--data_path``, ``--data_type``, ``--max_edge_level``, ``--device``,
        ``--seed``, ``--num_workers``, ``--pin_mem``).
        ``--ae_pth`` is the only required argument.
    """
    parser = argparse.ArgumentParser("Extracting Features", add_help=False)

    # Model parameters
    parser.add_argument(
        "--batch_size",
        default=32,
        type=int,
        help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)",
    )
    parser.add_argument(
        "--text_model_name",
        type=str,
        help="Text model name to use",
    )
    parser.add_argument(
        "--ae",
        type=str,
        metavar="MODEL",
        help="Name of autoencoder",
    )
    # Every other flag is snake_case; accept the underscore spelling as an
    # alias so both forms work. The destination stays `ae_latent_dim` because
    # argparse derives it from the first long option string.
    parser.add_argument(
        "--ae-latent-dim",
        "--ae_latent_dim",
        type=int,
        default=512*8,
        help="AE latent dimension",
    )
    parser.add_argument(
        "--ae_pth",
        required=True,
        help="Autoencoder checkpoint"
    )
    parser.add_argument(
        "--point_cloud_size",
        default=2048,
        type=int,
        help="input size"
    )
    parser.add_argument(
        "--fetch_keys",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--use_embeds",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--intensity_loss",
        action="store_true",
        default=False,
        help="Contrastive edit intensity loss using ground-truth labels.",
    )

    # Dataset parameters
    parser.add_argument(
        "--dataset",
        type=str,
        choices=["graphedits"],
        help="dataset name",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        help="dataset path",
    )
    parser.add_argument(
        "--data_type",
        type=str,
        help="dataset type",
    )
    parser.add_argument(
        "--max_edge_level",
        default=None,
        type=int,
        help="maximum edge level to use",
    )
    parser.add_argument(
        "--device", default="cuda", help="device to use for training / testing"
    )
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--num_workers", default=60, type=int)
    parser.add_argument(
        "--pin_mem",
        action="store_true",
        help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.",
    )

    return parser


# Stand-in command line so the parser can be driven from inside the notebook
call_string = """--ae_pth ckpt/ae_m512.pth \
    --ae kl_d512_m512_l8 \
    --ae-latent-dim 4096 \
    --batch_size 32 \
    --num_workers 8 \
    --device cuda"""

# Build the parser, then parse the whitespace-separated dummy arguments
parser = get_args_parser()
args = parser.parse_args(call_string.split())

# --------------------
# Runtime setup: target device, seeding, cuDNN autotuning
device = torch.device(args.device)
misc.set_all_seeds(args.seed)
torch.backends.cudnn.benchmark = True
# --------------------

# Look up the autoencoder constructor by name in the models module and build it
ae_factory = vars(ae_mods)[args.ae]
ae = ae_factory()
ae.eval()

# Restore the weights stored under the "model" key of the checkpoint
# (deserialized on CPU first)
print("Loading autoencoder %s" % args.ae_pth)
ae.load_state_dict(torch.load(args.ae_pth, map_location="cpu")["model"])

# Move to the target device, then wrap with torch.compile
ae = torch.compile(ae.to(device), mode="max-autotune")

/home/slimhy/Documents/PADS/code
Set seed to 0
Loading autoencoder ckpt/ae_m512.pth


In [2]:
from datasets.latents import ShapeLatentDataset, ComposedPairedShapesLoader
from datasets.metadata import class_to_hex

latents_dir = "/home/slimhy/Documents/datasets/PADS/3DCoMPaT"

# Latent dataset restricted to the "chair" class, with part shuffling disabled
dataset = ShapeLatentDataset(
    latents_dir,
    shuffle_parts=False,
    class_code=class_to_hex("chair"),
)

# Paired-shapes loader over the dataset, configured with two pair types
dataloader = ComposedPairedShapesLoader(
    dataset,
    batch_size=32,
    pair_types_list=['rand_no_rot,rand_no_rot', 'part_drop,orig'],
    num_workers=0,
    shuffle=False,
)

# Step through the loader and stop once the k_break-th batch has been unpacked;
# the unpacked tensors of that batch stay available to the following cells.
k_break = 1
k = 0
for paired_batch in dataloader:
    (
        (latents, part_bbs, part_labels, meta_A),
        (latent_B, bb_coords_B, bb_labels_B, meta_B),
    ) = paired_batch
    k += 1
    if k == k_break:
        break

In [3]:
# Take the first entry of the batch and place it on the compute device
latent_0 = latents[0].to(device)
part_bbs_0 = part_bbs[0].to(device)

In [4]:
from util import s2vs

# Turn the selected latent (with a leading batch axis added) into a mesh
# through the s2vs decoding helper.
mesh = s2vs.decode_latents(
    ae,
    latent_0.unsqueeze(0),
    grid_density=256,
    batch_size=64**3,
)

In [5]:
import numpy as np
import trimesh

def create_trimesh_boxes(bounding_boxes):
    """Build one trimesh box primitive per bounding box.

    For each entry, the box extents are the per-axis peak-to-peak range of its
    points and the box is translated to the per-axis mean of those points.

    Args:
        bounding_boxes: iterable of array-likes; each entry is reduced along
            axis 0 (assumed to be the points/corners axis -- confirm with caller).

    Returns:
        list: the created meshes, in input order.
    """

    def _box_from_points(points):
        # Size comes from the spread of the points, position from their mean
        box_mesh = trimesh.creation.box(extents=np.ptp(points, axis=0))
        box_mesh.apply_translation(np.mean(points, axis=0))
        return box_mesh

    return [_box_from_points(points) for points in bounding_boxes]

# Copy the part bounding boxes back to host memory as a NumPy array
part_bbs_host = part_bbs_0.cpu()
bounding_boxes = np.array(part_bbs_host)

# One trimesh box per part bounding box
trimesh_boxes = create_trimesh_boxes(bounding_boxes)


In [6]:
import k3d

# One integer colour per box, drawn from k3d's Rainbow colormap
unique_parts = np.array(range(len(trimesh_boxes)))
col_map = k3d.helpers.map_colors(unique_parts, k3d.colormaps.basic_color_maps.Rainbow)
col_map = list(map(int, col_map))

# Start an empty k3d scene
plot = k3d.plot()

all_bbs = []

# Decoded shape in light grey
shape_vertices = np.array(mesh.vertices.cpu().numpy())
shape_faces = np.array(mesh.faces.cpu().numpy())
plot += k3d.mesh(shape_vertices, shape_faces, color=0xefefef)

# Overlay every bounding box, semi-transparent, with its own colour
for k, bb_mesh in enumerate(trimesh_boxes):
    box_color = col_map[k]
    plot += k3d.mesh(bb_mesh.vertices, bb_mesh.faces, color=box_color, opacity=0.5)

plot



Plot(antialias=3, axes=['x', 'y', 'z'], axes_helper=1.0, axes_helper_colors=[16711680, 65280, 255], background…

In [17]:
# Display the vertices of the oriented bounding box of the decoded mesh
# (presumably `mesh.trimesh_mesh` is the underlying trimesh object -- confirm
# against util.s2vs).
mesh.trimesh_mesh.bounding_box_oriented.vertices

TrackedArray([[ 0.37249579, -0.49483607,  0.13550597],
              [ 0.36855525,  0.47969185,  0.15851934],
              [-0.11567337, -0.49682651,  0.13620537],
              [-0.11961391,  0.4777014 ,  0.15921874],
              [ 0.37196872, -0.48669762, -0.20921628],
              [ 0.36802819,  0.48783029, -0.18620291],
              [-0.11620044, -0.48868807, -0.20851688],
              [-0.12014097,  0.48583985, -0.18550351]])