In [1]:
%cd /ibex/user/slimhy/PADS/code/
"""
Extract autoencoder latents for a set of input shapes.

For each batch of shapes, the packaged surface points are loaded and encoded
with the autoencoder; the resulting latents can be written to disk.

NOTE(review): the %cd above hard-codes a cluster-specific checkout path;
adjust it before running elsewhere.
"""
import argparse
import torch
import os
import numpy as np

import util.misc as misc
import util.s2vs as s2vs

from datasets.sampling import sample_surface_tpp
from eval.metrics import chamfer_distance, f_score


def get_args_parser():
    """
    Build the command-line parser used for latent extraction.

    Options are registered from a declarative table, in the order listed, so
    the usage/help output keeps that order.

    Returns:
        argparse.ArgumentParser: parser created with ``add_help=False``; only
        ``--ae_pth`` is required.
    """
    parser = argparse.ArgumentParser("Extracting Latents", add_help=False)

    # (flag, add_argument keyword arguments), in registration order.
    arg_specs = [
        # Model parameters
        (
            "--batch_size",
            dict(
                type=int,
                help="Batch size per GPU (this is the grid dimension, to be cubed)",
            ),
        ),
        ("--ae", dict(type=str, metavar="MODEL", help="Name of autoencoder")),
        ("--ae_latent_dim", dict(type=int, default=512 * 8, help="AE latent dimension")),
        ("--ae_pth", dict(required=True, help="Autoencoder checkpoint")),
        # CUDA parameters
        ("--device", dict(default="cuda", help="device to use for training / testing")),
        ("--seed", dict(default=0, type=int)),
        ("--num_workers", dict(default=60, type=int)),
        (
            "--pin_mem",
            dict(
                action="store_true",
                help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.",
            ),
        ),
        # Datasplit parameters
        ("--obj_dir", dict(type=str, help="Path to object directory")),
        (
            "--opt_point_ratio",
            dict(
                type=int,
                help="Ratio of points to sample from the object surface (2^N)",
                default=1,
            ),
        ),
        # Hyperparameters
        ("--lr", dict(type=float, help="Learning rate")),
        ("--max_iter", dict(type=int, help="Maximum number of iterations for each stage")),
        # Logging
        ("--log_dir", dict(type=str, help="Logging directory")),
        ("--config_name", dict(type=str, help="Name of the active configuration")),
        # Distribution parameters
        ("--process_id", dict(type=int, help="ID of the current process")),
        ("--max_process", dict(type=int, help="Total number of processes")),
    ]

    for flag, options in arg_specs:
        parser.add_argument(flag, **options)

    return parser


def initialize_latents(args, ae, root_names):
    """
    Load the input data of a batch of shapes and encode it into initial latents.

    For every root name, three files are read from ``args.obj_dir``:
    ``<root>_surface_points.npy``, ``<root>_near_surface_points.npy`` and
    ``<root>_occs.npy``. Each kind is stacked along a new leading batch axis
    and placed on the autoencoder's device.

    Args:
        args: Parsed arguments; only ``args.obj_dir`` is read.
        ae: Autoencoder; its ``device`` attribute selects where tensors live.
        root_names: Sequence of shape root names forming the batch.

    Returns:
        Tuple ``(init_latents, points, near_points, occs)`` where
        ``init_latents`` is the detached output of ``s2vs.encode_pc`` on
        ``points`` (with FPS sampling).
    """
    device = ae.device

    def load_batch(suffix):
        """Stack ``<root><suffix>.npy`` for every root name into one tensor."""
        arrays = [
            np.load(os.path.join(args.obj_dir, root_name + suffix + ".npy"))
            for root_name in root_names
        ]
        # squeeze(1) removes axis 1 only if it has size 1 (no-op otherwise).
        # NOTE(review): presumably files are stored as (1, N, C); confirm.
        return torch.tensor(np.stack(arrays), device=device).squeeze(1)

    # Load the batch directly onto the autoencoder's device
    points = load_batch("_surface_points")
    near_points = load_batch("_near_surface_points")
    occs = load_batch("_occs")

    # Encode the surface points. The result is detached, so no autograd graph
    # is needed; no_grad avoids keeping encoder activations alive.
    with torch.no_grad():
        init_latents = s2vs.encode_pc(ae, points, fps_sampling=True).detach()

    return init_latents, points, near_points, occs


def get_metrics(ae, pc_original, latents, batch_size, grid_density=512):
    """
    Decode latents into a mesh and compare it with a reference point cloud.

    Args:
        ae: Autoencoder used by ``s2vs.decode_latents``.
        pc_original: Reference surface point cloud. Its first dimension sets
            how many points are sampled from the reconstructed mesh.
        latents: Latent code(s) to decode. They are passed through
            ``misc.d_GPU`` first (presumably moves them to the GPU).
        batch_size: Batch size forwarded to ``s2vs.decode_latents``.
        grid_density: Decoding grid density forwarded to
            ``s2vs.decode_latents``; defaults to 512, the previously
            hard-coded value.

    Returns:
        Tuple ``(rec_mesh, chamfer, f_sc)``: the reconstructed mesh, the
        Chamfer distance (kaolin backend) and the F-score between
        ``pc_original`` and points sampled from ``rec_mesh``.
    """
    rec_mesh = s2vs.decode_latents(
        ae=ae,
        latent=misc.d_GPU(latents),
        grid_density=grid_density,
        batch_size=batch_size,
    )
    # Sample as many points as the reference so both clouds are comparable
    pc_rec = sample_surface_tpp(rec_mesh, pc_original.shape[0])
    chamfer = chamfer_distance(pc_original, pc_rec, backend="kaolin")
    f_sc = f_score(gt=pc_original, pred=pc_rec)
    return rec_mesh, chamfer, f_sc

/ibex/user/slimhy/PADS/code


In [2]:
# Debug command line, parsed exactly as if it had been given to the script.
call_string = """--ae_pth ckpt/ae_m512.pth \
    --ae kl_d512_m512_l8 \
    --ae_latent_dim 4096 \
    --batch_size 128 \
    --num_workers 8 \
    --device cuda
    --obj_dir /ibex/project/c2273/3DCoMPaT/packaged \
    --process_id 0
    --max_process 1
    """

# Whitespace splitting makes the line continuations irrelevant
args = get_args_parser().parse_args(call_string.split())

# Reproducibility and cuDNN autotuning
misc.set_all_seeds(args.seed)
torch.backends.cudnn.benchmark = True
device = torch.device(args.device)

# Load the autoencoder once, in evaluation mode
ae = s2vs.load_model(args.ae, args.ae_pth, device, torch_compile=True).eval()

Set seed to 0
Loading autoencoder [ckpt/ae_m512.pth].


In [3]:
from datasets.metadata import class_to_hex
def get_root_name(f, no_rot):
    """
    Map a file name to the root name of the shape it belongs to.

    The "_near_surface" and "_surface" markers are removed first. Then every
    occurrence of "_<last token>" is removed, where the last token is what
    follows the final underscore (e.g. "points.npy" or "occs.npy").

    Args:
        f: File name, e.g. an entry from ``os.listdir``.
        no_rot: If True, files whose (marker-stripped) name contains
            "_all_aug_" or "_rand_rot_" are rejected.

    Returns:
        The root name, or None when the file is rejected by ``no_rot``.
    """
    stem = f.replace("_near_surface", "").replace("_surface", "")
    if no_rot and ("_all_aug_" in stem or "_rand_rot_" in stem):
        return None
    last_token = stem.rsplit("_", 1)[-1]
    return stem.replace("_" + last_token, "")


def get_all_files(path, class_name=None, no_rot=False):
    """
    List the unique shape root names found in an input directory.

    Each shape contributes several files (surface points, near-surface points
    and occupancies), which all map to the same root name, so the names are
    de-duplicated. They are also sorted because ``os.listdir`` order is
    arbitrary, which would make batching non-reproducible.

    Args:
        path: Directory to scan.
        class_name: If given, keep only roots starting with
            ``class_to_hex(class_name) + "_"``.
        no_rot: If True, drop files containing "_all_aug_" or "_rand_rot_".

    Returns:
        Sorted list of unique root names.
    """
    root_names = {get_root_name(f, no_rot=no_rot) for f in os.listdir(path)}
    root_names.discard(None)

    if class_name is not None:
        class_hex = class_to_hex(class_name) + "_"
        root_names = {f for f in root_names if f.startswith(class_hex)}

    return sorted(root_names)


def batch_list(root_names, batch_size):
    """
    Split a sequence into consecutive chunks of at most ``batch_size`` items.

    The final chunk holds the remainder and may be shorter.
    """
    batches = []
    for start in range(0, len(root_names), batch_size):
        batches.append(root_names[start : start + batch_size])
    return batches

In [4]:
def save_latents(out_dir, root_batch, latents):
    """
    Save one latent per root name as ``<out_dir>/<root>.npy``.

    Args:
        out_dir: Output directory; created (with parents) if missing.
        root_batch: Sequence of root names, one per latent.
        latents: Tensor or sequence of tensors, indexed in the same order as
            ``root_batch``. Each entry is detached and moved to CPU before
            saving.

    Raises:
        ValueError: If ``root_batch`` and ``latents`` have different lengths.
    """
    if len(root_batch) != len(latents):
        raise ValueError(
            f"Got {len(root_batch)} root names but {len(latents)} latents."
        )
    os.makedirs(out_dir, exist_ok=True)
    for root, latent in zip(root_batch, latents):
        path = os.path.join(out_dir, root + ".npy")
        np.save(path, latent.detach().cpu().numpy())
        
# Packaged per-shape .npy inputs (same directory as --obj_dir in the debug call string).
IN_PATH: str = "/ibex/project/c2273/3DCoMPaT/packaged"
# Destination directory for the per-shape latent files.
OUT_PATH: str = "/ibex/project/c2273/3DCoMPaT/latents"

In [5]:
# `args` and the autoencoder `ae` come from the setup cell; reloading (and
# re-compiling) the model here is unnecessary. Re-seeding keeps re-runs of
# this cell starting from the same RNG state.
misc.set_all_seeds(args.seed)

# Set to True to write one <root>.npy latent per shape into OUT_PATH.
SAVE_LATENTS = False
# Debugging switch: set to False to process every batch, not only the first.
FIRST_BATCH_ONLY = True

root_names = get_all_files(args.obj_dir)
root_batches = batch_list(root_names, args.batch_size)
for root_batch in root_batches:
    # Load the batch inputs and encode them (latents are already detached)
    init_latents, points, near_points, occs = initialize_latents(args, ae, root_batch)

    if SAVE_LATENTS:
        save_latents(OUT_PATH, root_batch, init_latents)

    if FIRST_BATCH_ONLY:
        break

Set seed to 0
Loading autoencoder [ckpt/ae_m512.pth].
