In [1]:
%cd /ibex/user/slimhy/PADS/code/
"""
Extracting features into HDF5 files for each split.
"""
import argparse
import torch
import trimesh

import util.misc as misc
import util.s2vs as s2vs

from datasets.shapeloaders import CoMPaTManifoldDataset, ShapeNetDataset
from datasets.metadata import SHAPENET_NAME_TO_SYNSET, SHAPENET_NAME_TO_SYNSET_INDEX
from inversion.evaluate import evaluate_reconstruction
from util.misc import d_GPU, show_side_by_side

/ibex/user/slimhy/PADS/code


In [2]:
def get_args_parser():
    """
    Build the command-line parser used by this notebook.

    The parser is created with ``add_help=False``, so it does not register an
    automatic ``-h/--help`` option. ``--ae_pth`` is the only required argument;
    every boolean flag is a ``store_true`` switch that is ``False`` unless given.

    Returns:
        argparse.ArgumentParser: Parser exposing the model, dataset and
        runtime options declared below.
    """
    parser = argparse.ArgumentParser("Extracting Features", add_help=False)

    # Model parameters
    parser.add_argument(
        "--batch_size",
        default=32,
        type=int,
        help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)",
    )
    parser.add_argument(
        "--text_model_name",
        type=str,
        help="Text model name to use",
    )
    parser.add_argument(
        "--ae",
        type=str,
        metavar="MODEL",
        help="Name of autoencoder",
    )
    # NOTE: this option is spelled with dashes; argparse exposes it as
    # `args.ae_latent_dim`.
    parser.add_argument(
        "--ae-latent-dim",
        type=int,
        default=512 * 8,
        help="AE latent dimension",
    )
    parser.add_argument(
        "--ae_pth",
        required=True,
        help="Autoencoder checkpoint",
    )
    parser.add_argument(
        "--point_cloud_size",
        default=2048,
        type=int,
        help="input size",
    )
    parser.add_argument(
        "--fetch_keys",
        action="store_true",
    )
    parser.add_argument(
        "--use_embeds",
        action="store_true",
    )
    parser.add_argument(
        "--intensity_loss",
        action="store_true",
        help="Contrastive edit intensity loss using ground-truth labels.",
    )

    # Dataset parameters
    parser.add_argument(
        "--dataset",
        type=str,
        choices=["graphedits"],
        help="dataset name",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        help="dataset path",
    )
    parser.add_argument(
        "--data_type",
        type=str,
        help="dataset type",
    )
    parser.add_argument(
        "--max_edge_level",
        default=None,
        type=int,
        help="maximum edge level to use",
    )

    # Runtime parameters
    parser.add_argument(
        "--device", default="cuda", help="device to use for training / testing"
    )
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--num_workers", default=60, type=int)
    parser.add_argument(
        "--pin_mem",
        action="store_true",
        help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.",
    )

    return parser

In [3]:
# Set dummy arg string to debug the parser.
# The line-ending backslashes sit inside the triple-quoted literal, so Python
# joins the lines and `call_string` holds one line of text.
call_string = """--ae_pth ckpt/ae_m512.pth \
    --ae kl_d512_m512_l8 \
    --ae-latent-dim 4096 \
    --batch_size 32 \
    --num_workers 8 \
    --device cuda"""

# Parse the arguments. The parser keeps its own name so that `args` only ever
# refers to the parsed namespace (it used to be bound to the parser first).
parser = get_args_parser()
args = parser.parse_args(call_string.split())

In [4]:
# Runtime setup: target device, seeding, and cuDNN benchmark mode
device = torch.device(args.device)
misc.set_all_seeds(args.seed)
torch.backends.cudnn.benchmark = True

# Load the autoencoder checkpoint (torch_compile=True) and put it in eval mode
ae = s2vs.load_model(
    args.ae,
    args.ae_pth,
    device,
    torch_compile=True,
).eval()

Set seed to 0
Loading autoencoder [ckpt/ae_m512.pth].


In [254]:
# Shape category used by the cells below
ACTIVE_CLASS = "vase"

# Point budgets; both equal 2**21 = 2,097,152
N_POINTS = 1 << 21
MAX_POINTS = 1 << 21  # Maximum number of points for a single batch on a A100 GPU

In [255]:
import os
import numpy as np
import torch
from torch.utils import data
from datasets.metadata import (
    SHAPENET_CLASSES,
    get_compat_transform,
    get_shapenet_transform,
)


class ShapeNetDataset(data.Dataset):
    """
    Surface point clouds of a single ShapeNet category.

    Every ``<dataset_folder>/<shape_cls>/4_pointcloud/<model>.npz`` file is one
    shape. Indexing returns a random subset of ``pc_size`` points read from the
    file's ``"points"`` array, together with the ``SHAPENET_CLASSES`` entry of
    the category.

    Args:
        dataset_folder: Root folder of the dataset.
        shape_cls: Name of the category sub-folder to read (required).
        transform: Stored on the instance; not applied by ``__getitem__``.
        sampling: Stored on the instance; not used by this class.
        num_samples: Stored on the instance; not used by this class.
        pc_size: Number of surface points drawn per item.
        replica: Length multiplier applied when ``split == "train"``.
        split: Split name read by ``__len__``. Defaults to ``"train"`` so that
            ``replica`` takes effect; any other value disables replication.

    Raises:
        ValueError: If ``shape_cls`` is not provided.
    """

    def __init__(
        self,
        dataset_folder,
        shape_cls=None,
        transform=None,
        sampling=True,
        num_samples=4096,
        pc_size=2048,
        replica=16,
        split="train",
    ):
        if shape_cls is None:
            # Fail early with a clear message instead of the TypeError that
            # os.path.join would raise on None.
            raise ValueError("shape_cls must be the name of a category folder")

        self.pc_size = pc_size

        self.transform = transform
        self.num_samples = num_samples
        self.sampling = sampling

        self.return_surface = True
        self.surface_sampling = True

        # __len__ reads this attribute; it was never assigned before, which
        # made len(dataset) raise AttributeError.
        self.split = split
        self.replica = replica

        self.point_folder = dataset_folder
        self.dataset_folder = os.path.join(dataset_folder, shape_cls)

        subpath = os.path.join(self.dataset_folder, "4_pointcloud")

        # os.listdir returns entries in arbitrary order: sort them so that an
        # index always designates the same shape. Only *.npz files are shapes,
        # and only the trailing extension is stripped from their names.
        self.models = [
            {"category": shape_cls, "model": os.path.splitext(fname)[0]}
            for fname in sorted(os.listdir(subpath))
            if fname.endswith(".npz")
        ]

    def __getitem__(self, idx):
        """
        Return ``(surface, label)`` for the shape at ``idx``.

        ``surface`` is a float32 tensor with a leading axis of size 1 added by
        ``unsqueeze(0)``; ``label`` is ``SHAPENET_CLASSES[category]``.
        """
        # Indices past the number of shapes (replicas) wrap around.
        idx = idx % len(self.models)

        entry = self.models[idx]
        category = entry["category"]
        model = entry["model"]

        pc_path = os.path.join(
            self.dataset_folder, "4_pointcloud", model + ".npz"
        )
        # `npz` rather than `data`, which would shadow the torch.utils.data module.
        with np.load(pc_path) as npz:
            surface = npz["points"].astype(np.float32)
        if self.surface_sampling:
            # A fresh, unseeded generator is created on every call, so the
            # drawn subset is not controlled by any global seed. Sampling is
            # without replacement: the file must hold at least pc_size points.
            ind = np.random.default_rng().choice(
                surface.shape[0], self.pc_size, replace=False
            )
            surface = surface[ind]
        surface = torch.from_numpy(surface).unsqueeze(0)

        return surface, SHAPENET_CLASSES[category]

    def __len__(self):
        """Number of items: shapes times ``replica`` for the train split."""
        if self.split != "train":
            return len(self.models)
        return len(self.models) * self.replica


In [256]:

def _rotate_points(pc, rotation_rows):
    """
    Right-multiply the last (xyz) axis of ``pc`` by a 3x3 matrix.

    The matrix is created with the dtype and device of ``pc``. Leading axes of
    ``pc`` are preserved as they are (no ``squeeze``), so a single-point cloud
    or a batch of several clouds keeps its shape. Inputs with fewer than three
    axes gain a leading axis, which is what the previous implementation
    returned for them.
    """
    rotation = torch.tensor(rotation_rows, dtype=pc.dtype, device=pc.device)
    rotated = torch.matmul(pc, rotation)
    if rotated.dim() < 3:
        rotated = rotated.unsqueeze(0)
    return rotated


def flip_front_to_back(pc):
    """
    Rotate 180° around Y axis (from front-facing to back-facing).

    Maps every point (x, y, z) to (-x, y, -z).

    Args:
        pc: Tensor whose last axis holds xyz coordinates.

    Returns:
        Tensor of rotated points, with at least three axes.
    """
    return _rotate_points(
        pc, [[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]]
    )


def flip_front_to_right(pc):
    """
    Rotate 90° around Y axis (from front-facing to right-facing).

    Maps every point (x, y, z) to (z, y, -x).

    Args:
        pc: Tensor whose last axis holds xyz coordinates.

    Returns:
        Tensor of rotated points, with at least three axes.
    """
    return _rotate_points(
        pc, [[0.0, 0.0, -1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]]
    )


def flip_front_to_left(pc):
    """
    Rotate 90° around Y axis (from front-facing to left-facing).

    Maps every point (x, y, z) to (-z, y, x); the matrix is the transpose of
    the one used by flip_front_to_right, i.e. the opposite rotation.

    Args:
        pc: Tensor whose last axis holds xyz coordinates.

    Returns:
        Tensor of rotated points, with at least three axes.
    """
    return _rotate_points(
        pc, [[0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]]
    )


def flip_left_to_right(pc):
    """
    Rotate 180° around Y axis (from left-facing to right-facing).

    Same matrix as flip_front_to_back: (x, y, z) becomes (-x, y, -z).

    Args:
        pc: Tensor whose last axis holds xyz coordinates.

    Returns:
        Tensor of rotated points, with at least three axes.
    """
    return _rotate_points(
        pc, [[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]]
    )


# Orientation transform applied to each compat class (keyword form, same
# entries in the same order)
COMPAT_TRANSFORMS = dict(
    airplane=flip_front_to_right,
    bag=flip_front_to_back,
    basket=flip_front_to_left,
    bed=flip_front_to_back,
    bench=flip_front_to_back,
    bird_house=flip_front_to_back,
    boat=flip_front_to_right,
    cabinet=flip_front_to_left,
    dresser=flip_front_to_right,
    car=flip_front_to_right,
    chair=flip_front_to_back,
    dishwasher=flip_front_to_back,
    faucet=flip_front_to_back,
    jug=flip_front_to_right,
    lamp=flip_front_to_right,
    love_seat=flip_front_to_back,
    shelf=flip_front_to_right,
    skateboard=flip_front_to_right,
    sofa=flip_front_to_back,
    sports_table=flip_front_to_right,
    table=flip_front_to_right,
    trashcan=flip_front_to_right,
)

# Orientation transform applied to each ShapeNet class; classes that need no
# change map to a pass-through lambda
SHAPENET_TRANSFORMS = dict(
    airplane=lambda pc: pc,
    bag=flip_front_to_back,
    basket=flip_front_to_left,
    bed=lambda pc: pc,
    bench=lambda pc: pc,
    bird_house=lambda pc: pc,
    cabinet=flip_front_to_left,
    dresser=flip_front_to_left,
    car=lambda pc: pc,
    chair=lambda pc: pc,
    dishwasher=lambda pc: pc,
    jar=lambda pc: pc,
    lamp=lambda pc: pc,
    love_seat=lambda pc: pc,
    ottoman=lambda pc: pc,
    shelf=flip_front_to_left,
    skateboard=lambda pc: pc,
    sofa=lambda pc: pc,
    table=flip_front_to_right,
    trashcan=flip_front_to_left,
    vase=lambda pc: pc,
)


def get_compat_transform(class_name):
    """
    Look up the transform registered for a compat class.

    Classes missing from COMPAT_TRANSFORMS get a pass-through function.
    """
    try:
        return COMPAT_TRANSFORMS[class_name]
    except KeyError:
        return lambda x: x


def get_shapenet_transform(class_name):
    """
    Look up the transform registered for a ShapeNet class.

    Classes missing from SHAPENET_TRANSFORMS get a pass-through function.
    """
    if class_name in SHAPENET_TRANSFORMS:
        return SHAPENET_TRANSFORMS[class_name]
    return lambda x: x


In [257]:
# Root folder of the ShapeNet data and index of the shape to load.
OBJ_DIR = "/ibex/project/c2273/ShapeNet/"
OBJ_ID = 10

# Initialize the latents
# The dataset is restricted to the synset of ACTIVE_CLASS; the number of points
# per item is taken from the autoencoder (`ae.num_inputs`), with replica=1.
orig_dataset = ShapeNetDataset(
    dataset_folder=OBJ_DIR,
    shape_cls=SHAPENET_NAME_TO_SYNSET[ACTIVE_CLASS],
    pc_size=ae.num_inputs,
    replica=1,
)
# Indexing returns a (surface, label) pair; only the surface points are kept.
surface_points, _ = orig_dataset[OBJ_ID]
surface_points = surface_points.to(device)
# NOTE(review): `init_latents` is assigned again in a later cell, after the
# class orientation transform has been applied to `surface_points`.
init_latents = s2vs.encode_pc(ae, surface_points).detach()

In [258]:
def center_mesh(mesh):
    """
    Subtract the mesh centroid from its vertices and return the mesh.

    The mesh is modified in place; the returned object is the one passed in.
    """
    offset = mesh.centroid
    mesh.vertices -= offset
    return mesh

# Load shape
def get_gen(
    class_name,
    shape_dir="/ibex/user/slimhy/3DShape2VecSet/class_cond_obj/kl_d512_m512_l8_d24_edm/",
):
    """
    Load and center the mesh stored for a class in ``shape_dir``.

    The file read is ``"<index>-00000.obj"``, where ``<index>`` is
    ``SHAPENET_NAME_TO_SYNSET_INDEX[class_name]`` zero-padded to two digits.

    Args:
        class_name: Key of ``SHAPENET_NAME_TO_SYNSET_INDEX``.
        shape_dir: Folder holding the OBJ files. Defaults to the location
            that used to be hard-coded here; a trailing slash is optional.

    Returns:
        The object returned by ``trimesh.load``, after ``center_mesh``.

    Raises:
        KeyError: If ``class_name`` is not in ``SHAPENET_NAME_TO_SYNSET_INDEX``.
    """
    class_id = SHAPENET_NAME_TO_SYNSET_INDEX[class_name]
    # os.path.join yields the same path as plain concatenation when shape_dir
    # ends with a separator, and inserts one when it does not.
    obj_path = os.path.join(shape_dir, "%02d-00000.obj" % class_id)
    return center_mesh(trimesh.load(obj_path))

In [259]:
# NOTE(review): this cell rebinds `surface_points` to its own transformed value,
# so running it more than once applies the class transform repeatedly. For
# "vase" the entry in SHAPENET_TRANSFORMS is a pass-through, so this is harmless.
surface_points = surface_points.to(device)
surface_points = get_shapenet_transform(ACTIVE_CLASS)(surface_points)
# Encode the transformed points; this replaces the earlier `init_latents`.
init_latents = s2vs.encode_pc(ae, surface_points)
# Decode the latents back into a mesh (grid_density=256, batch_size=128**3).
rec_mesh = s2vs.decode_latents(ae, d_GPU(init_latents), grid_density=256, batch_size=128**3)

# Load the stored mesh for the same class and display both meshes together.
shapenet_gen = get_gen(ACTIVE_CLASS)
show_side_by_side(rec_mesh, shapenet_gen, flip_front_to_back=False)

In [260]:
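# NOTE: kept for reference only. Presumably the command that produced the
# class-conditional OBJ samples read by get_gen(); confirm the paths and
# checkpoints before re-running it.
# %cd /ibex/user/slimhy/3DShape2VecSet/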
# ! python sample_class_cond.py \
#     --ae kl_d512_m512_l8 \
#     --ae-pth output/ae/kl_d512_m512_l8/checkpoint-199.pth \
#     --dm kl_d512_m512_l8_d24_edm \
#     --dm-pth output/class_cond_dm/kl_d512_m512_l8_d24_edm/checkpoint-499.pth
# 