In [None]:
%cd /ibex/user/slimhy/PADS/code

In [None]:
# Add "3DCoMPaT" to the path
import sys
sys.path.append("jupyter/3DCoMPaT/")
import utils3D.plot as plt_utils
from compat3D import ShapeLoader

In [None]:
"""
Inspect ShapeNet vs. 3DCoMPaT point clouds and compute per-class average bounding-box statistics.
"""
import argparse
import os

import numpy as np
import torch

import util.misc as misc
import models.autoencoders as ae_mods

In [None]:
def get_args_parser():
    """
    Build the command-line parser used by this notebook.

    The parser is created with ``add_help=False``, so it defines no
    ``-h/--help`` option of its own. ``--ae_pth`` is the only required
    argument; every other option either has a default or is left as None.

    Returns:
        argparse.ArgumentParser: the configured (not yet parsed) parser.
    """
    parser = argparse.ArgumentParser("Extracting Features", add_help=False)

    # Model parameters
    parser.add_argument(
        "--batch_size",
        default=32,
        type=int,
        help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)",
    )
    parser.add_argument(
        "--text_model_name",
        type=str,
        help="Text model name to use",
    )
    parser.add_argument(
        "--ae",
        type=str,
        metavar="MODEL",
        help="Name of autoencoder",
    )
    # Note the hyphenated flag: argparse exposes it as `args.ae_latent_dim`.
    parser.add_argument(
        "--ae-latent-dim",
        type=int,
        default=512*8,
        help="AE latent dimension",
    )
    parser.add_argument(
        "--ae_pth",
        required=True,
        help="Autoencoder checkpoint"
    )
    parser.add_argument(
        "--point_cloud_size",
        default=2048,
        type=int,
        help="input size"
    )
    parser.add_argument(
        "--fetch_keys",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--use_embeds",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--intensity_loss",
        action="store_true",
        default=False,
        help="Contrastive edit intensity loss using ground-truth labels.",
    )

    # Dataset parameters
    parser.add_argument(
        "--dataset",
        type=str,
        choices=["graphedits"],
        help="dataset name",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        help="dataset path",
    )
    parser.add_argument(
        "--data_type",
        type=str,
        help="dataset type",
    )
    parser.add_argument(
        "--max_edge_level",
        default=None,
        type=int,
        help="maximum edge level to use",
    )
    parser.add_argument(
        "--device", default="cuda", help="device to use for training / testing"
    )
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--num_workers", default=60, type=int)
    parser.add_argument(
        "--pin_mem",
        action="store_true",
        help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.",
    )

    return parser

In [None]:
# Set dummy arg string to debug the parser
call_string = """--ae_pth ckpt/ae_m512.pth \
    --ae kl_d512_m512_l8 \
    --ae-latent-dim 4096 \
    --batch_size 32 \
    --num_workers 8 \
    --device cuda"""
    

# Parse the arguments
args = get_args_parser()
args = args.parse_args(call_string.split())

In [None]:
# --------------------
device = torch.device(args.device)

# Fix the seed for reproducibility
misc.set_all_seeds(args.seed)

torch.backends.cudnn.benchmark = True
# --------------------

# Instantiate autoencoder
ae = ae_mods.__dict__[args.ae]()
ae.eval()
print("Loading autoencoder %s" % args.ae_pth)
ae.load_state_dict(torch.load(args.ae_pth, map_location="cpu")["model"])
ae = ae.to(device)

In [None]:
import mcubes
import trimesh


GRID_DENSITY = 128


def get_grid(density=None, grid_device=None):
    """
    Build a regular 3D query grid covering the cube [-1, 1]^3.

    Args:
        density: number of cells per axis; the grid has ``density + 1`` samples
            per axis. Defaults to ``GRID_DENSITY``.
        grid_device: device the grid is moved to. Defaults to the
            notebook-level ``device``.

    Returns:
        torch.Tensor of shape (1, (density + 1) ** 3, 3), dtype float32.
        Points follow ``np.meshgrid``'s default "xy" indexing: the flattened
        index runs over (y, x, z), with z varying fastest.
    """
    # Resolve defaults at call time so later changes to the globals are honored.
    if density is None:
        density = GRID_DENSITY
    if grid_device is None:
        grid_device = device

    axis = np.linspace(-1, 1, density + 1)
    xv, yv, zv = np.meshgrid(axis, axis, axis)
    coords = np.stack([xv, yv, zv]).astype(np.float32)

    # (3, Ny, Nx, Nz) -> (3, P) -> (P, 3) -> (1, P, 3)
    grid = torch.from_numpy(coords).view(3, -1).transpose(0, 1)[None]
    return grid.to(grid_device, non_blocking=True)

SAMPLE_GRID = get_grid()


@torch.inference_mode()
def decode_latent(latent):
    """
    Decode a single autoencoder latent into a triangle mesh via marching cubes.

    Args:
        latent: array-like or tensor with 512 * 8 elements; it is reshaped to
            (1, 512, 8) before decoding (presumably the layout expected by the
            kl_d512_m512_l8 autoencoder -- confirm if another AE is used).

    Returns:
        trimesh.Trimesh: the zero level set of the decoded volume, with
        vertices mapped from voxel-index coordinates back to [-1, 1].
    """
    # as_tensor avoids the copy-construct warning that torch.tensor() emits for
    # tensor inputs. Using `device` (rather than .cuda()) keeps the latent on
    # the same device as `ae` and SAMPLE_GRID.
    latent = torch.as_tensor(latent, dtype=torch.float32, device=device).reshape(1, 512, 8)

    # Under inference_mode the output carries no autograd graph, so no detach is needed.
    logits = ae.decode(latent, SAMPLE_GRID)

    # permute(1, 0, 2) undoes the (y, x, z) ordering that np.meshgrid's default
    # "xy" indexing gives the sample grid, yielding an (x, y, z) volume.
    side = GRID_DENSITY + 1
    volume = logits.view(side, side, side).permute(1, 0, 2).cpu().numpy()
    verts, faces = mcubes.marching_cubes(volume, 0)

    # Marching cubes returns vertices in voxel-index units [0, GRID_DENSITY];
    # rescale them to the [-1, 1] cube the grid was sampled from.
    gap = 2. / GRID_DENSITY
    verts *= gap
    verts -= 1

    return trimesh.Trimesh(verts, faces)

In [None]:
def normalize_pc(point_cloud, use_center_of_bounding_box=True, scale=8.0):
    """
    Center a point cloud at the origin and rescale it to a fixed radius.

    Args:
        point_cloud: tensor of shape (N, 3) (one point per row).
        use_center_of_bounding_box: if True, center on the midpoint of the
            axis-aligned bounding box; otherwise center on the mean point.
        scale: radius of the result. The point farthest from the center ends
            up at distance ``scale`` (default 8.0, so every coordinate lies in
            [-8, 8]).

    Returns:
        torch.Tensor: the normalized cloud, same shape and device as the input.
        A degenerate cloud (all points identical) is returned as all zeros
        rather than NaN.
    """
    if use_center_of_bounding_box:
        lo = point_cloud.min(dim=0).values
        hi = point_cloud.max(dim=0).values
        center = (lo + hi) / 2
    else:
        center = point_cloud.mean(dim=0)
    centered = point_cloud - center

    # Largest Euclidean distance of any point from the center
    dist = torch.max(torch.sqrt(torch.sum(centered ** 2, dim=1)))
    # Guard against 0 / 0 when every point coincides with the center
    if dist > 0:
        centered = centered / dist
    return centered * scale


In [None]:
def load_npy(path, scale):
    """Load a .npy array from `path`, cast it to float32, and multiply it by `scale`."""
    return np.load(path).astype(np.float32) * scale

def flip_pc(pc):
    """
    Swap the Y and Z columns of an (N, 3) point cloud.

    List indexing along the second axis returns a new array; the input is left
    untouched.

    NOTE(review): this is a pure axis swap, not a sign flip, so it mirrors the
    shape (changes handedness). If a rotation was intended, one of the swapped
    axes should also be negated -- confirm against the target convention.
    """
    # Column order (x, y, z) -> (x, z, y)
    x_z_y = [0, 2, 1]
    return pc[:, x_z_y]


In [None]:
import zipfile
from shapenet_synsets import SHAPENET_CLASSES

ZIP_SRC = "/ibex/user/slimhy/surfaces.zip"


def shapenet_iterator(shape_cls, zip_path=None):
    """
    Yield the "points" array of every ShapeNet surface sample of one class.

    Args:
        shape_cls: key into SHAPENET_CLASSES; the mapped value is used as the
            name prefix selecting archive members of that class.
        zip_path: archive to read from. Defaults to ZIP_SRC, resolved at call time.

    Yields:
        np.ndarray: the "points" entry of each matching ``.npz`` member, in
        archive order.
    """
    if zip_path is None:
        zip_path = ZIP_SRC
    prefix = SHAPENET_CLASSES[shape_cls]

    with zipfile.ZipFile(zip_path, 'r') as archive:
        for member in archive.namelist():
            if not (member.startswith(prefix) and member.endswith(".npz")):
                continue
            # Materialize the array, then close both the member stream and the
            # NpzFile before yielding. This way no per-member handle stays open
            # while the consumer holds the generator suspended.
            with archive.open(member) as stream, np.load(stream) as npz:
                points = npz["points"]
            yield points

In [None]:
from shapenet_synsets import COMPAT_CLASSES

METADATA_DIR = "/ibex/user/slimhy/3DCoMPaT/3DCoMPaT-v2/metadata"
ZIP_PATH = "/ibex/user/slimhy/3DCoMPaT/3DCoMPaT_ZIP.zip"
N_POINTS = 2**15


def compat_iterator(shape_cls, split="train", n_points=None, seed=0):
    """
    Yield ``(pointcloud, point_part_labels)`` pairs for one 3DCoMPaT class.

    Args:
        shape_cls: key into COMPAT_CLASSES; the mapped value is passed to
            ShapeLoader as ``filter_class``.
        split: dataset split forwarded to ShapeLoader (default "train").
        n_points: points sampled per shape. Defaults to N_POINTS, resolved at call time.
        seed: seed for ShapeLoader's shuffling (default 0).

    Yields:
        The point cloud and per-point part labels of each shape, in the order
        ShapeLoader produces them (shuffled).
    """
    if n_points is None:
        n_points = N_POINTS

    dataset = ShapeLoader(zip_path=ZIP_PATH,
                          meta_dir=METADATA_DIR,
                          split=split,
                          n_points=n_points,
                          shuffle=True,
                          seed=seed,
                          filter_class=COMPAT_CLASSES[shape_cls])

    # ShapeLoader yields (shape_id, shape_label, pointcloud, point_part_labels);
    # only the geometry and part labels are needed here.
    for _shape_id, _shape_label, pointcloud, point_part_labels in dataset:
        yield pointcloud, point_part_labels

In [None]:
# For every 3DCoMPaT class, average the per-shape bounding-box min/max corners
# (after the Y/Z swap applied by flip_pc) and save them to stats_compat/<class>.npz.
os.makedirs("stats_compat", exist_ok=True)

for compat_class in COMPAT_CLASSES:
    # Running sums of per-shape min/max corners
    avg_min = torch.zeros(3)
    avg_max = torch.zeros(3)
    total_shapes = 0

    for pc, _ in compat_iterator(compat_class):
        # Swap axes first, then cast: the swapped cloud is the one measured.
        pc_tensor = torch.tensor(flip_pc(pc).astype(np.float32))

        # Per-axis min/max (torch.min/max propagate NaN)
        min_ = torch.min(pc_tensor, dim=0)[0]
        max_ = torch.max(pc_tensor, dim=0)[0]

        # If there is a NaN value, skip the shape
        if torch.isnan(min_).any() or torch.isnan(max_).any():
            continue

        avg_min += min_
        avg_max += max_
        total_shapes += 1

    # Avoid 0 / 0 (NaN stats) when a class has no valid shapes
    if total_shapes == 0:
        print("No valid shapes for [%s], skipping. \n" % compat_class)
        continue

    avg_min /= total_shapes
    avg_max /= total_shapes

    print(f"Average min: {avg_min}, Average max: {avg_max}")

    np.savez(f"stats_compat/{compat_class}.npz", avg_min=avg_min.numpy(), avg_max=avg_max.numpy())
    print("Saved stats for [%s]. \n" % compat_class)

In [None]:
# Lazily created generators over chair shapes from each dataset; nothing is
# read until the plotting cells call next() on them.
chair_it_cp = compat_iterator("chair")
chair_it_sp = shapenet_iterator("chair")

In [None]:
# First ShapeNet chair, plotted as stored (no axis swap)
chair_tensor = next(chair_it_sp)

plot_kwargs = dict(
    size=8,
    cmap="viridis",
    point_size=2,
    semantic_level="coarse",
    pic_height=256,
)
plt_utils.plot_pointclouds([chair_tensor], **plot_kwargs)

In [None]:
# First 3DCoMPaT chair (geometry only), with Y/Z swapped before plotting --
# presumably to match the ShapeNet axis convention.
chair_tensor = next(chair_it_cp)[0]
chair_swapped = flip_pc(np.array(chair_tensor))

plot_kwargs = dict(
    size=8,
    cmap="viridis",
    point_size=2,
    semantic_level="coarse",
    pic_height=256,
)
plt_utils.plot_pointclouds([chair_swapped], **plot_kwargs)

In [None]:
from shapenet_synsets import COMPAT_CLASSES

# Iterate over compat classes
for compat_class in COMPAT_CLASSES:
    # Average min/max values
    avg_min = torch.zeros(3)
    avg_max = torch.zeros(3)
    
    total_shapes = 0

    # Iterate over all the point clouds in the class
    for pc in shapenet_iterator(compat_class):
        pc_tensor = pc.astype(np.float32)
        pc_tensor = torch.tensor(pc_tensor)

        # Compute min/max across each dimension
        min_, max_ = torch.min(pc_tensor, dim=0)[0], torch.max(pc_tensor, dim=0)[0]
        
        # If there is a NaN value, skip the shape
        if torch.isnan(min_).any() or torch.isnan(max_).any():
            continue
        
        # Update the average min/max values
        avg_min += min_
        avg_max += max_
        
        total_shapes += 1
        
    avg_min /= total_shapes
    avg_max /= total_shapes
    
    print(f"Average min: {avg_min}, Average max: {avg_max}")
    
    # Log to npz file in the "stats" folder
    np.savez(f"stats_shapenet/{compat_class}.npz", avg_min=avg_min, avg_max=avg_max)
    print("Saved stats for [%s]. \n" % compat_class)