In [1]:
# Switch the kernel's working directory to the PADS code folder: later cells
# use paths relative to it ("jupyter/3DCoMPaT/", "ckpt/ae_m512.pth").
%cd /ibex/user/slimhy/PADS/code/

/ibex/user/slimhy/PADS/code


In [2]:
# Add "3DCoMPaT" to the path
# The entry is relative, so it is resolved against the current working
# directory (set by the %cd cell at the top of the notebook).
import sys
sys.path.append("jupyter/3DCoMPaT/")
# presumably compat3D lives under jupyter/3DCoMPaT/ — confirm
from compat3D import ShapeLoader

In [3]:
"""
Extracting features into HDF5 files for each split.
"""
import argparse
import numpy as np
import torch

import util.misc as misc
import models.autoencoders as ae_mods

In [4]:
def get_args_parser():
    """
    Build the command-line parser used by this notebook.

    The parser is created with ``add_help=False``, so it does not register
    ``-h/--help`` itself. ``--ae_pth`` is the only required option; every
    other option falls back to its default (``None`` when no default is given).

    Returns:
        argparse.ArgumentParser: the configured parser. Note that
        ``--ae-latent-dim`` is exposed as ``ae_latent_dim`` on the parsed
        namespace (argparse converts dashes to underscores).
    """
    parser = argparse.ArgumentParser("Extracting Features", add_help=False)

    # Model parameters
    parser.add_argument(
        "--batch_size",
        default=32,
        type=int,
        help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)",
    )
    parser.add_argument(
        "--text_model_name",
        type=str,
        help="Text model name to use",
    )
    parser.add_argument(
        "--ae",
        type=str,
        metavar="MODEL",
        help="Name of autoencoder",
    )
    parser.add_argument(
        "--ae-latent-dim",
        type=int,
        default=512*8,
        help="AE latent dimension",
    )
    parser.add_argument(
        "--ae_pth",
        required=True,
        help="Autoencoder checkpoint"
    )
    parser.add_argument(
        "--point_cloud_size",
        default=2048,
        type=int,
        help="input size"
    )
    parser.add_argument(
        "--fetch_keys",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--use_embeds",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--intensity_loss",
        action="store_true",
        default=False,
        help="Contrastive edit intensity loss using ground-truth labels.",
    )

    # Dataset parameters
    parser.add_argument(
        "--dataset",
        type=str,
        choices=["graphedits"],
        help="dataset name",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        help="dataset path",
    )
    parser.add_argument(
        "--data_type",
        type=str,
        help="dataset type",
    )
    parser.add_argument(
        "--max_edge_level",
        default=None,
        type=int,
        help="maximum edge level to use",
    )

    # Runtime parameters
    parser.add_argument(
        "--device", default="cuda", help="device to use for training / testing"
    )
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--num_workers", default=60, type=int)
    parser.add_argument(
        "--pin_mem",
        action="store_true",
        help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.",
    )

    return parser

In [5]:
# Hand-written argument string standing in for a real command line (debugging aid)
call_string = """--ae_pth ckpt/ae_m512.pth \
    --ae kl_d512_m512_l8 \
    --ae-latent-dim 4096 \
    --batch_size 32 \
    --num_workers 8 \
    --device cuda"""

# Build the parser and parse the whitespace-split string in a single step,
# so that `args` only ever refers to the parsed namespace.
args = get_args_parser().parse_args(call_string.split())

In [6]:
# --------------------
# Target device for the model and the query grids (taken from --device)
device = torch.device(args.device)

# Fix the seed for reproducibility
misc.set_all_seeds(args.seed)

# NOTE(review): cuDNN benchmark mode autotunes kernels and may pick different
# algorithms between runs, which can work against the reproducibility goal
# above — confirm this trade-off is intended.
torch.backends.cudnn.benchmark = True
# --------------------

# Instantiate autoencoder
# The factory is looked up by name (--ae) in the module namespace and called
# with no arguments; --ae-latent-dim is not passed to it here.
ae = ae_mods.__dict__[args.ae]()
ae.eval()
print("Loading autoencoder %s" % args.ae_pth)
# The checkpoint is deserialised on CPU and the weights are read from its
# "model" entry before the module is moved to `device`.
# NOTE(review): torch.load may unpickle arbitrary objects depending on the
# PyTorch version/`weights_only` default — only load trusted checkpoints.
ae.load_state_dict(torch.load(args.ae_pth, map_location="cpu")["model"])
ae = ae.to(device)

# Compile using torch.compile
# `ae` is rebound to the compiled wrapper; all later cells use that wrapper.
ae = torch.compile(ae, mode="max-autotune")

Set seed to 0
Loading autoencoder ckpt/ae_m512.pth


In [7]:
import mcubes
import trimesh


def get_grid(grid_density=128):
    """
    Build a dense grid of 3D query points spanning the cube [-1, 1]^3.

    Args:
        grid_density: number of intervals per axis; each axis is sampled at
            ``grid_density + 1`` evenly spaced positions.

    Returns:
        A float32 tensor of shape ``(1, (grid_density + 1) ** 3, 3)`` placed
        on the module-level ``device``. Point order follows ``np.meshgrid``
        with its default indexing.
    """
    axis = np.linspace(-1, 1, grid_density+1)
    xv, yv, zv = np.meshgrid(axis, axis, axis)
    coords = np.stack([xv, yv, zv]).astype(np.float32)
    # (3, G, G, G) -> (3, G^3) -> (G^3, 3)
    points = torch.from_numpy(coords).view(3, -1).transpose(0, 1)
    # Add the leading batch axis and ship to the target device
    return points[None].to(device, non_blocking=True)


@torch.inference_mode()
def query_latents_from_set(latent, point_queries, batch_size=None):
    """
    Decode `latent` at the given query points with the module-level `ae`.

    Args:
        latent: latent representation passed straight to ``ae.decode``.
        point_queries: query points; they are sliced along dimension 1.
        batch_size: when ``None`` all queries are decoded in one call and
            the result is returned as produced by the decoder. Otherwise
            queries are decoded in chunks of at most ``batch_size`` points,
            each chunk is moved to the CPU, and the chunks are concatenated
            along dimension 1.

    Returns:
        The decoder outputs (on the CPU when `batch_size` is given).
    """
    n_queries = point_queries.shape[1]

    # Single-shot path: no chunking, result stays where the decoder put it
    if batch_size is None:
        return ae.decode(latent, point_queries).detach()

    # Chunked path: accumulate per-chunk outputs on the CPU
    chunks = []
    for lo, hi in batch_slices(n_queries, batch_size):
        chunk_logits = ae.decode(latent, point_queries[:, lo:hi, :])
        chunks.append(chunk_logits.detach().cpu())
    return torch.cat(chunks, dim=1)


@torch.inference_mode()
def query_latents(latent, grid_density, batch_size=None):
    """
    Decode `latent` on the dense grid produced by ``get_grid(grid_density)``.

    `batch_size` is forwarded unchanged to ``query_latents_from_set``.
    """
    return query_latents_from_set(latent, get_grid(grid_density), batch_size)


@torch.inference_mode()
def decode_latents(latent, grid_density=128, batch_size=None):
    """
    Turn a latent into a triangle mesh via marching cubes on a dense grid.

    Args:
        latent: latent representation to decode; the decoded output must hold
            exactly ``(grid_density + 1) ** 3`` values (it is viewed as a cube).
        grid_density: number of grid intervals per axis.
        batch_size: optional chunk size forwarded to ``query_latents``.

    Returns:
        A ``trimesh.Trimesh`` extracted at iso-level 0, with vertices rescaled
        from grid indices to the [-1, 1] cube.
    """
    side = grid_density+1
    grid_logits = query_latents(latent, grid_density, batch_size)
    # Swap the first two axes of the cubic volume
    # (presumably to undo np.meshgrid's default axis ordering — confirm)
    volume = grid_logits.view(side, side, side).permute(1, 0, 2)
    verts, faces = mcubes.marching_cubes(volume.cpu().numpy(), 0)
    # Grid index -> coordinate: one cell spans 2 / grid_density, origin at -1
    verts *= 2. / grid_density
    verts -= 1
    return trimesh.Trimesh(verts, faces)


def batch_slices(total, batch_size):
    """
    Generate batch start and end indices.

    Yields consecutive half-open ``(start, end)`` pairs that cover
    ``range(total)`` in chunks of at most `batch_size`; the last chunk may
    be shorter. Nothing is yielded when ``total <= 0``.

    Args:
        total: number of items to cover.
        batch_size: maximum chunk length; must be strictly positive.

    Raises:
        ValueError: if `batch_size` is not strictly positive. Because this is
            a generator, the error surfaces when iteration starts.
    """
    # A non-positive step would never advance `start` and loop forever
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer, got %r" % (batch_size,))

    start = 0
    while start < total:
        end = min(start + batch_size, total)
        yield start, end
        start = end


@torch.inference_mode()
def encode_pc(pc):
    """
    Encode a point cloud with the module-level autoencoder `ae`.

    Args:
        pc: NumPy array of points, indexed along its first axis. It must hold
            at least ``ae.num_inputs`` rows, since rows are drawn at random
            without replacement.

    Returns:
        The second element of the pair returned by ``ae.encode`` for the
        subsampled cloud (batched with a leading axis of size 1).
    """
    picked = np.random.choice(pc.shape[0], ae.num_inputs, replace=False)
    # Add the batch axis and cast to float32 on the target device
    pc_batch = torch.from_numpy(pc[picked]).to(device).unsqueeze(0).float()
    _, x_a = ae.encode(pc_batch)
    return x_a


@torch.inference_mode()
def encode_decode(pc, grid_density=128, batch_size=None):
    """
    Round-trip a point cloud through the autoencoder.

    Encodes `pc` with ``encode_pc`` and decodes the resulting latent with
    ``decode_latents``.

    Returns:
        A ``(mesh, latent)`` tuple: the reconstructed mesh and the latent it
        was decoded from.
    """
    latent = encode_pc(pc)
    mesh = decode_latents(latent, grid_density, batch_size)
    return mesh, latent


In [8]:
import zipfile
from metadata import SHAPENET_CLASSES, COMPAT_CLASSES, COMPAT_TRANSFORMS


# Shape class used when instantiating SingleManifoldDataset in a later cell
ACTIVE_CLASS = "chair"
# Passed as `meta_dir` to ShapeLoader in compat_iterator
METADATA_DIR = "/ibex/user/slimhy/3DCoMPaT/3DCoMPaT-v2/metadata"
# Zip archive read by shapenet_iterator (its .npz members carry a "points" array)
ZIP_SRC = "/ibex/user/slimhy/surfaces.zip"
# Passed as `zip_path` to ShapeLoader in compat_iterator
ZIP_PATH = "/ibex/user/slimhy/3DCoMPaT/3DCoMPaT_ZIP.zip"
# Number of points requested from the loaders (2**21 = 2,097,152)
N_POINTS = 2**21


def shapenet_iterator(shape_cls):
    """
    Yield the point arrays stored in the ZIP_SRC archive for one class.

    Only archive members whose name starts with ``SHAPENET_CLASSES[shape_cls]``
    and ends with ".npz" are read. The archive stays open for as long as the
    generator is being consumed.

    Args:
        shape_cls: key into ``SHAPENET_CLASSES``.

    Yields:
        The "points" array of each matching member, cast to float32.
    """
    with zipfile.ZipFile(ZIP_SRC, 'r') as archive:
        for member in archive.namelist():
            # Keep only .npz members belonging to the requested class prefix
            if member.startswith(SHAPENET_CLASSES[shape_cls]) and member.endswith(".npz"):
                with archive.open(member) as handle:
                    data = np.load(handle)
                    yield data["points"].astype(np.float32)
        

def compat_iterator(shape_cls):
    """
    Yield transformed point clouds from the 3DCoMPaT train split for one class.

    A ``ShapeLoader`` is built over ZIP_PATH / METADATA_DIR, filtered on
    ``COMPAT_CLASSES[shape_cls]``, shuffled with seed 0. Each sample's point
    cloud is passed through ``COMPAT_TRANSFORMS[shape_cls]``; the other
    fields of the sample are ignored.

    Args:
        shape_cls: key into ``COMPAT_CLASSES`` and ``COMPAT_TRANSFORMS``.
    """
    train_dataset = ShapeLoader(
        zip_path=ZIP_PATH,
        meta_dir=METADATA_DIR,
        split="train",
        n_points=N_POINTS,
        shuffle=True,
        seed=0,
        filter_class=COMPAT_CLASSES[shape_cls],
    )

    # Samples are 4-tuples; only the point cloud (third field) is used
    for _shape_id, _shape_label, points, _part_labels in train_dataset:
        transformed = COMPAT_TRANSFORMS[shape_cls](points)
        yield transformed

In [9]:
from shapeloaders import SingleManifoldDataset

# Directory handed to SingleManifoldDataset as its first argument
OBJ_DIR = "/ibex/user/slimhy/PADS/data/obj_manifold/"
# Index of the object selected from the dataset below
OBJ_ID = 0

# Put the autoencoder in evaluation mode
# (eval() only switches module behaviour; it does not by itself disable gradients)
ae.eval()

# Instantiate the dataset
# This instance uses the "volume+near_surface" sampling method.
shape_dataset = SingleManifoldDataset(OBJ_DIR,
                                      ACTIVE_CLASS,
                                      N_POINTS,
                                      normalize=False,
                                      sampling_method="volume+near_surface",
                                      contain_method="occnets",
                                      decimate=True,
                                      sample_first=False)
# Entry for the selected object (presumably an iterator over samples, as with
# orig_dataset in a later cell — confirm)
it_shape = shape_dataset[OBJ_ID]

# Second dataset over the same directory and class, using the "surface"
# sampling method; no latents are created in this cell.
orig_dataset = SingleManifoldDataset(OBJ_DIR,
                                     ACTIVE_CLASS,
                                     N_POINTS,
                                     normalize=False,
                                     sampling_method="surface")

In [15]:
# Offset added to OBJ_ID when picking a shape; the following cell increments it
k = 0

In [18]:
# NOTE: this cell is not idempotent — each run advances k, so re-running it
# moves on to the next object (k is incremented before it is used).
k += 1
# Indexing the dataset gives an iterator; take its first (points, _) sample
surface_points, _ = next(orig_dataset[OBJ_ID+k])
# presumably the mesh loaded by the access above — confirm against SingleManifoldDataset
orig_mesh = orig_dataset.obj
orig_mesh.show()

Mesh is not watertight! Performing robust conversion...
Watertight conversion successful!


In [19]:
# Sanity check: is the mesh fetched in the previous cell watertight?
orig_mesh.is_watertight

True