In [1]:
%cd /ibex/user/slimhy/PADS/code
import argparse
import numpy as np
import torch
import trimesh
import k3d

import util.misc as misc
import models.s2vs as ae_mods
from datasets.latents import ShapeLatentDataset, ComposedPairedShapesLoader
from models.partvae import PartAwareVAE
from util import s2vs

def get_args_parser():
    """Return the argparse parser for the autoencoder-visualization notebook.

    The parser is created with ``add_help=False`` (no automatic ``-h/--help``).
    Options cover:
      * the autoencoder checkpoint path (required) and architecture name;
      * the AE latent dimension;
      * data loading (batch size, worker count, pinned memory);
      * the compute device and random seed.
    """
    # (flag, keyword arguments for add_argument) in registration order.
    option_table = [
        ("--ae_pth", {"required": True, "help": "Autoencoder checkpoint"}),
        ("--ae", {"type": str, "default": "kl_d512_m512_l8", "help": "Name of autoencoder"}),
        ("--ae-latent-dim", {"type": int, "default": 4096, "help": "AE latent dimension"}),
        ("--batch_size", {"type": int, "default": 32, "help": "Batch size"}),
        ("--num_workers", {"type": int, "default": 8, "help": "Number of workers for data loading"}),
        ("--device", {"default": "cuda", "help": "Device to use for computation"}),
        ("--seed", {"type": int, "default": 0, "help": "Random seed"}),
        ("--pin_mem", {"action": "store_true", "help": "Pin CPU memory in DataLoader"}),
    ]
    parser = argparse.ArgumentParser("Autoencoder Visualization", add_help=False)
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser

def obb_to_corners(box_array):
    """Return the 8 corner points of an oriented bounding box.

    ``box_array`` is read row-wise: row 0 is the box center, and rows 1, 2 and 3 are
    the right, up and forward axis vectors. Each corner is
    ``center ± right ± up ± forward``, so the axis vectors act as half-extents.

    Returns:
        np.ndarray of shape (8, 3) for 3-D vectors. The four corners on the
        ``-forward`` side come first, then the four on the ``+forward`` side. Each
        group is ordered (-,-), (+,-), (+,+), (-,+) in (right, up) signs.
    """
    rows = [np.array(box_array[idx]) for idx in range(4)]
    center = rows[0]
    axes = np.vstack(rows[1:])  # shape (3, D): right, up, forward as rows
    signs = np.array([
        [-1, -1, -1], [ 1, -1, -1], [ 1,  1, -1], [-1,  1, -1],
        [-1, -1,  1], [ 1, -1,  1], [ 1,  1,  1], [-1,  1,  1],
    ])
    return signs @ axes + center

def create_trimesh_boxes(bounding_boxes):
    """Build one corner-marker mesh per non-empty oriented bounding box.

    For every box, eight small spheres (radius 0.01) are placed at the corners
    returned by ``obb_to_corners`` and concatenated into a single trimesh mesh.
    Boxes whose entries are all zero are skipped (presumably padding slots).

    Returns:
        list of meshes, one per kept box, in the input order.
    """
    marker_meshes = []
    for box in bounding_boxes:
        if np.all(box == 0):
            continue
        corner_spheres = [
            trimesh.primitives.Sphere(radius=0.01).apply_translation(corner)
            for corner in obb_to_corners(box)
        ]
        marker_meshes.append(trimesh.util.concatenate(corner_spheres))
    return marker_meshes

def load_checkpoint(model, checkpoint_path, weights_only=False):
    """Load weights from a ``torch.save`` file into ``model`` in place and return it.

    The file is mapped to CPU memory on load. Training checkpoints are expected to
    hold the state dict under the ``"model"`` key; a bare state dict is accepted too.

    Args:
        model: ``torch.nn.Module`` to populate via ``load_state_dict``.
        checkpoint_path: path of the checkpoint file.
        weights_only: forwarded to ``torch.load``. It is pinned explicitly because
            PyTorch 2.6 changed the default from False to True. The True setting
            rejects checkpoints that also pickle non-tensor objects (e.g. an
            ``argparse.Namespace`` of training args). False unpickles arbitrary
            objects, so only use it with trusted files.

    Returns:
        The same ``model`` object, for chaining.
    """
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=weights_only)
    # Full training checkpoints wrap the weights; plain state dicts are used as-is.
    if isinstance(checkpoint, dict) and "model" in checkpoint:
        state_dict = checkpoint["model"]
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    return model

def visualize_decoded_mesh(decoded_mesh, bounding_boxes):
    """Render a decoded mesh together with its part bounding-box corner markers in k3d.

    Args:
        decoded_mesh: object exposing ``vertices`` and ``faces``. It is drawn in
            semi-transparent blue.
        bounding_boxes: oriented boxes passed to ``create_trimesh_boxes``. All-zero
            boxes are skipped there. Each kept box gets its own colour from the
            Rainbow colormap.

    Returns:
        The k3d plot object; call ``.display()`` on it to show it.
    """
    trimesh_boxes = create_trimesh_boxes(bounding_boxes)
    # Pin the colour range explicitly. If map_colors infers it from min/max of the
    # indices, a single box yields a zero-width range (0/0 -> NaN colour) and zero
    # boxes make .min() fail. For two or more boxes this equals the inferred range.
    col_map = k3d.helpers.map_colors(
        np.arange(len(trimesh_boxes)),
        k3d.colormaps.basic_color_maps.Rainbow,
        color_range=(0, max(len(trimesh_boxes) - 1, 1)),
    )

    plot = k3d.plot()
    plot += k3d.mesh(decoded_mesh.vertices, decoded_mesh.faces, color=0x0000ff, opacity=0.5)
    for k, bb_mesh in enumerate(trimesh_boxes):
        plot += k3d.mesh(bb_mesh.vertices, bb_mesh.faces, color=int(col_map[k]), opacity=0.5)
    return plot

# Debug configuration: feed the parser a fixed, hard-coded argument string instead of
# the process's real command line. The trailing backslashes are inside the string
# literal, where Python treats backslash-newline as a continuation. The string is
# therefore a single line, and str.split() tokenizes it on whitespace.
call_string = """--ae_pth ckpt/ae_m512.pth \
    --ae kl_d512_m512_l8 \
    --ae-latent-dim 4096 \
    --batch_size 32 \
    --num_workers 8 \
    --device cuda \
    --seed 0 \
    --pin_mem"""

args = get_args_parser().parse_args(args=call_string.split())

In [4]:
device = torch.device(args.device)
misc.set_all_seeds(args.seed)
torch.backends.cudnn.benchmark = True

# Inputs used by this cell. These are absolute paths on the original cluster; edit
# them to run elsewhere.
PARTVAE_CKPT_PATH = (
    "/ibex/user/slimhy/PADS/output/partvae/"
    "partvae_partvae__kl_1e5__rec_08__inv_02__schedulefree/checkpoint-50.pth"
)
LATENTS_ROOT = "/ibex/project/c2273/PADS/3DCoMPaT/"
PAIR_TYPES = ["part_drop,orig"]

# Shape autoencoder: the architecture is looked up by name in models.s2vs and its
# weights are loaded from args.ae_pth. It is then put in eval mode, moved to
# `device` and wrapped with torch.compile.
ae = load_checkpoint(ae_mods.__dict__[args.ae](), args.ae_pth)
ae = torch.compile(ae.eval().to(device), mode="max-autotune")

# Part-aware VAE. latent_dim/depth come from args when the parser defines them,
# otherwise they fall back to the hard-coded values below.
model = PartAwareVAE(
    dim=512,
    latent_dim=getattr(args, "part_latents_dim", 128),
    heads=8,
    dim_head=64,
    depth=getattr(args, "layer_depth", 12),
).to(device)
model = load_checkpoint(model, PARTVAE_CKPT_PATH)
model.eval()

# Test-split latents, served as composed pairs in a fixed (unshuffled) order.
dataset_val = ShapeLatentDataset(LATENTS_ROOT, split="test", shuffle_parts=False)
data_loader_val = ComposedPairedShapesLoader(
    dataset_val,
    batch_size=args.batch_size,
    pair_types_list=PAIR_TYPES,
    num_workers=args.num_workers,
    seed=args.seed,
    shuffle=False,
    pin_memory=args.pin_mem,
    drop_last=True,
)

# Take only the first batch. Only the second element's (latents, bboxes, labels) are
# kept, presumably the "a" side of each pair; the third element is ignored.
# next(iter(...)) raises StopIteration on an empty loader instead of leaving stale
# l_a/bb_a/bb_l_a from an earlier run in the kernel.
pair_types, (l_a, bb_a, bb_l_a, _), _ = next(iter(data_loader_val))

# Move the sample to the compute device. Part slots labelled -1 are masked out
# (presumably padding). The mask inherits bb_l_a's device.
l_a, bb_a, bb_l_a = l_a.to(device), bb_a.to(device), bb_l_a.to(device)
mask_a = bb_l_a != -1


In [24]:
obj_k = 5  # index of the shape within the batch to decode and visualize

# Run the part-aware VAE on the whole batch, then decode the selected shape's
# predicted latents into a mesh with the autoencoder. The decode runs exactly once,
# inside no_grad, so no autograd graph is recorded for the dense
# (grid_density=256) decode.
with torch.no_grad():
    model_out = model(latents=l_a, part_bbs=bb_a, part_labels=bb_l_a, batch_mask=mask_a)
    # VAE variants return an extra middle element (kl_a, presumably a KL term).
    if getattr(model, "is_vae", False):
        logits_a, kl_a, part_latents_a = model_out
    else:
        logits_a, part_latents_a = model_out
    logits_in = logits_a[obj_k].unsqueeze(0)  # keep a leading batch dimension of 1
    mesh = s2vs.decode_latents(ae, logits_in, grid_density=256, batch_size=64**3)

print("Input latents shape:", l_a.shape)
print("Decoded latents shape:", logits_a.shape)
print("Part latents shape:", part_latents_a.shape)

# Overlay the selected shape's part bounding boxes on the decoded mesh.
bounding_boxes = np.array(bb_a[obj_k].cpu())
plot = visualize_decoded_mesh(mesh.trimesh_mesh, bounding_boxes)
plot.display()