In [1]:
%cd ../..

/ibex/user/slimhy/Shape2VecSet/code


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [2]:
"""
Extracting features into HDF5 files for each split.
"""
import argparse
import h5py
import json
import numpy as np
import os
import os.path as osp
import pprint

import torch
import torch.backends.cudnn as cudnn
import warnings

from transformers import AutoTokenizer, CLIPTextModel, BertTokenizer, BertModel

import util.misc as misc
import models.autoencoders as ae_mods
from engine_node2node import get_text_embeddings
from util.datasets import build_shape_surface_occupancy_dataset

from argparse import ArgumentParser
from models.imnet import IMNetAE


def load_imnet():
    if torch.cuda.is_available():
        device = torch.device("cuda")
        torch.backends.cudnn.benchmark = True
    else:
        device = torch.device("cpu")

    imnet = IMNetAE(sample_vox_size=64, device=device)

    # Load checkpoint
    ckpt_path = "/ibex/user/slimhy/Shape2VecSet/output/imnet/imnet_ckpt.pth"
    imnet.IMNet.load_state_dict(torch.load(ckpt_path))

    return imnet

In [3]:
def get_args_parser():
    parser = argparse.ArgumentParser("Extracting Features", add_help=False)

    # Model parameters
    parser.add_argument(
        "--batch_size",
        default=32,
        type=int,
        help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus",
    )
    parser.add_argument(
        "--text_model_name",
        type=str,
        help="Text model name to use",
    )
    parser.add_argument(
        "--ae",
        type=str,
        metavar="MODEL",
        help="Name of autoencoder",
    )
    parser.add_argument(
        "--ae_latent_dim",
        type=int,
        default=512*8,
        help="AE latent dimension",
    )
    parser.add_argument(
        "--ae_pth",
        help="Autoencoder checkpoint"
    )
    parser.add_argument(
        "--point_cloud_size",
        default=2048,
        type=int,
        help="input size"
    )
    parser.add_argument(
        "--fetch_keys",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--use_embeds",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--intensity_loss",
        action="store_true",
        default=False,
        help="Contrastive edit intensity loss using ground-truth labels.",
    )
    parser.add_argument(
        "--get_voxels",
        action="store_true",
        default=False,
        help="Fetch voxels instead of pointclouds.",
    )

    # Dataset parameters
    parser.add_argument(
        "--dataset",
        type=str,
        choices=["graphedits"],
        help="dataset name",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        help="dataset path",
    )
    parser.add_argument(
        "--data_type",
        type=str,
        help="dataset type",
    )
    parser.add_argument(
        "--max_edge_level",
        default=None,
        type=int,
        help="maximum edge level to use",
    )
    parser.add_argument(
        "--device", default="cuda", help="device to use for training / testing"
    )
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--num_workers", default=60, type=int)
    parser.add_argument(
        "--pin_mem",
        action="store_true",
        help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.",
    )

    return parser

In [4]:
# Set dummy arg string to debug the parser
call_string = """--text_model_name bert-base-uncased \
    --ae_latent_dim 256 \
    --dataset graphedits \
    --data_path /ibex/user/slimhy/ShapeWalk/ \
    --data_type release_chained \
    --batch_size 1 \
    --num_workers 1 \
    --get_voxels \
    --device cuda \
    --fetch_keys \
    --seed 0"""

# Parse the arguments
args = get_args_parser()
args = args.parse_args(call_string.split())

In [5]:
# --------------------
args.use_clip = "clip" in args.text_model_name
device = torch.device(args.device)

# Fix the seed for reproducibility
seed = args.seed + misc.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)

cudnn.benchmark = True

args.fetch_keys = True

dataset_train = build_shape_surface_occupancy_dataset("train", args=args)
dataset_val = build_shape_surface_occupancy_dataset("val", args=args)

# Create data loaders
data_loader_train = torch.utils.data.DataLoader(
    dataset_train,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.num_workers,
    pin_memory=True,
    drop_last=True,
)

data_loader_val = torch.utils.data.DataLoader(
    dataset_val,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.num_workers,
    pin_memory=True,
    drop_last=True,
)
# --------------------

# Instantiate autoencoder
print("Loading autoencoder [%s]..." % args.ae_pth)
imnet = load_imnet()
print("Done.")

print("Loading text model [%s]..." % args.text_model_name)
# Initialize text CLIP model
if args.use_clip:
    # Instantiate tokenizer + CLIP model
    tokenizer = AutoTokenizer.from_pretrained(args.text_model_name)
    text_model = CLIPTextModel.from_pretrained(args.text_model_name).to(device)
else:
    # Instantiate BERT model and create linear projection layer
    tokenizer = BertTokenizer.from_pretrained(args.text_model_name)
    text_model = BertModel.from_pretrained(args.text_model_name).to(device)

text_model = text_model.eval()
print("Done.")

Loading autoencoder [None]...
Done.
Loading text model [bert-base-uncased]...
Done.


In [6]:
def encode_voxels(voxels):
    imnet.IMNet = imnet.IMNet.eval()
    voxels = voxels.to(imnet.device)
    model_z, _ = imnet.IMNet(voxels)
    return model_z

In [7]:
iter_debug = iter(data_loader_train)

In [8]:
import trimesh
import mcubes

DEBUG_AE = True
if DEBUG_AE:
    voxels = next(iter_debug)
    voxels = voxels[1]

    # Safety check: fetch a single batch
    model_z = encode_voxels(voxels)
    model_float = imnet.z2voxel(model_z)

    vertices, triangles = mcubes.marching_cubes(
        model_float, imnet.sampling_threshold
    )
    vertices = (vertices.astype(np.float32) - 0.5) / imnet.real_size - 0.5


In [9]:
trimesh.Trimesh(vertices, triangles).show()

In [10]:
from tqdm.notebook import trange, tqdm


def check_key(key, key_set):
    key_a, key_b = key.split("_")
    return key_a in key_set or key_b in key_set


def get_split_embeddings(args, data_loader, missing_keys):
    text_latent_dim = 512 if args.use_clip else 768

    # Create stacked numpy arrays to store embeddings
    B = args.batch_size
    n_batches = len(data_loader)
    n_entries = n_batches * B
    embeds_xa   = np.zeros((n_entries, args.ae_latent_dim))
    embeds_xb   = np.zeros((n_entries, args.ae_latent_dim))
    embeds_text = np.zeros((n_entries, text_latent_dim))
    all_keys = ["_" for _ in range(n_entries)]

    # Iterate over the dataset and extract text embeddings
    # Use iterator instead
    data_iter = iter(data_loader)
    wrong_keys = []

    for k in trange(n_batches):
        # Catch ANY exception
        edit_keys, nodes_a, nodes_b, prompts_ab = next(data_iter)

        # Check that the key is in the missing keys
        if not check_key(edit_keys[0], missing_keys):
            continue
    
        nodes_a = nodes_a.to(device, non_blocking=True)
        nodes_b = nodes_b.to(device, non_blocking=True)
        embeds_ab = get_text_embeddings(text_model=text_model,
                                        tokenizer=tokenizer,
                                        texts=prompts_ab,
                                        device=device)

        with torch.no_grad():
            x_a = encode_voxels(nodes_a)
            x_b = encode_voxels(nodes_b)

        # Move batch to CPU and convert to numpy
        embeds_ab = embeds_ab.cpu().numpy()
        x_a = x_a.cpu().numpy()
        x_b = x_b.cpu().numpy()

        # Store to stacked arrays
        embeds_xa[k*B:(k+1)*B] = x_a.reshape(B, -1)
        embeds_xb[k*B:(k+1)*B] = x_b.reshape(B, -1)
        embeds_text[k*B:(k+1)*B] = embeds_ab.reshape(B, -1)
        all_keys[k*B:(k+1)*B] = edit_keys

    for key in wrong_keys:
       print("%s\n" % key)

    return all_keys, embeds_xa, embeds_xb, embeds_text


def extract_missing_embeddings(data_loader, missing_keys):
    """
    Extract embeddings and remap to shape/edit keys.
    """
    # Extract embeddings
    all_keys, embeds_xa, embeds_xb, embeds_text = get_split_embeddings(args, data_loader, missing_keys)

    edit_keys_sp = [k.split('_') for k in all_keys]
    keys_node_a = [k[0] for k in edit_keys_sp]
    keys_node_b = [k[1] for k in edit_keys_sp]
    
    # Map node keys to indices
    node_a_to_idx = {k: i for i, k in enumerate(keys_node_a)}
    node_b_to_idx = {k: i for i, k in enumerate(keys_node_b)}
    all_nodes = list(set(keys_node_a + keys_node_b))
    print("all_nodes=", len(all_nodes))

    # Build a matrix with all the embeddings
    # using the indices
    shape_embeds = np.zeros((len(all_nodes), args.ae_latent_dim))
    k = 0
    key_to_shape_embeds = {}
    for node_key in node_a_to_idx:
        idx = node_a_to_idx[node_key]
        shape_embeds[k] = embeds_xa[idx]
        key_to_shape_embeds[node_key] = k
        k += 1

    # Remove all nodes already added from node_b
    for node_key in (node_b_to_idx.keys() - node_a_to_idx.keys()):
        idx = node_b_to_idx[node_key]
        shape_embeds[k] = embeds_xb[idx]
        key_to_shape_embeds[node_key] = k
        k += 1

    # Double check that everything is correct
    # Iterate on edit_keys
    print("Checking shape embeddings...")
    intersec_nodes = node_a_to_idx.keys() & node_b_to_idx.keys()
    for node_a, node_b in edit_keys_sp:
        assert node_a in key_to_shape_embeds
        assert node_b in key_to_shape_embeds
    
        # Check that embeddings are correct
        idx_a = key_to_shape_embeds[node_a]
        idx_b = key_to_shape_embeds[node_b]
       
        if node_a not in intersec_nodes:
            assert np.allclose(shape_embeds[idx_a], embeds_xa[node_a_to_idx[node_a]])
        if node_b not in intersec_nodes:
            assert np.allclose(shape_embeds[idx_b], embeds_xb[node_b_to_idx[node_b]])
    print("Done!")

    key_pair_to_text_embeds = {key_pair : k for k, key_pair in enumerate(all_keys)}

    # Double check that everything is correct
    # Iterate on edit_keys
    print("Checking text embeddings...")
    for key_pair in all_keys:
        assert key_pair in key_pair_to_text_embeds

    print("Done!")

    return shape_embeds, key_to_shape_embeds, embeds_text, key_pair_to_text_embeds

In [11]:
def complete_hdf5(args, split, shape_embeds, key_to_shape_embeds, text_embeds, key_pair_to_text_embeds):
    """
    Open existing HDF5 file and replace the shape_embeds and text_embeds of the provided keys.
    """
    # Open HDF5 file
    hdf5_path = os.path.join(args.data_path, args.data_type, "embeddings_%s__imnet.hdf5" % split)
    print("Opening HDF5 file %s" % hdf5_path)
    f = h5py.File(hdf5_path, "r")

    # Resize f["shape_embeds"] to accomodate new embeddings
    n_new_embeds = len(key_to_shape_embeds) - 1
    n_old_embeds = f["shape_embeds"].shape[0]
    shape_embeds = np.zeros((n_old_embeds + n_new_embeds, args.ae_latent_dim))

    hdf5_key2shape_embed = json.loads(f["key_to_shape_embeds"][0])
    current_idx = len(hdf5_key2shape_embed)
    for key in key_to_shape_embeds:
        if key == '': continue

        # Load embeds from current buffers
        idx = key_to_shape_embeds[key]
        key_shape_embeds = shape_embeds[idx]

        # Append to HDF5
        shape_embeds[current_idx] = key_shape_embeds
        hdf5_key2shape_embed[key] = current_idx
        current_idx += 1

    # Resize f["text_embeds"] to accomodate new embeddings
    n_new_embeds = len(key_pair_to_text_embeds) - 1
    n_old_embeds = f["text_embeds"].shape[0]
    text_embeds = np.zeros((n_old_embeds + n_new_embeds, 512 if args.use_clip else 768))

    hdf5_key2pairtext_embed = json.loads(f["key_pair_to_text_embeds"][0])
    current_idx = len(hdf5_key2pairtext_embed)
    for key in key_pair_to_text_embeds:
        if key == '_': continue

        # Load embeds from current buffers
        idx = key_pair_to_text_embeds[key]
        key_text_embeds = text_embeds[idx]

        # Replace in HDF5
        text_embeds[current_idx] = key_text_embeds
        hdf5_key2pairtext_embed[key] = current_idx
        current_idx += 1

    # Close HDF5 file
    f.close()

    # Create new HDF5 file
    hdf5_path = os.path.join(args.data_path, args.data_type, "embeddings_%s__imnet_copy.hdf5" % split)
    print("Creating HDF5 file %s" % hdf5_path)
    f_new = h5py.File(hdf5_path, "w")

    # Delete empty keys from hdf5_key2shape_embed
    if '' in hdf5_key2pairtext_embed:
        del hdf5_key2pairtext_embed['']
    else:
        print("hmm")

    # Create datasets
    f_new.create_dataset("shape_embeds", data=shape_embeds)
    f_new.create_dataset("text_embeds", data=text_embeds)
    f_new.create_dataset("key_to_shape_embeds",
                     data=json.dumps(hdf5_key2shape_embed),
                     shape=(1,),
                     dtype=h5py.string_dtype(encoding="utf-8"))
    f_new.create_dataset("key_pair_to_text_embeds",
                     data=json.dumps(hdf5_key2pairtext_embed),
                     shape=(1,),
                     dtype=h5py.string_dtype(encoding="utf-8"))
    f_new.close()

    print("Done! Replaced %d embeddings." % len(key_to_shape_embeds))

In [12]:
missing_keys = set(['538085594287396785',
'511204417948777787',
'515890682497093265',
'567678021262807605',
'579732311956622259',
'535736676014233363',
'561481371803113784',
'547322459226286617'])

In [13]:
shape_embeds_train, key_to_shape_embeds_train, text_embeds_train, key_pair_to_text_embeds_train = extract_missing_embeddings(data_loader_train, missing_keys)

  0%|          | 0/79040 [00:00<?, ?it/s]

all_nodes= 24
Checking shape embeddings...
Done!
Checking text embeddings...
Done!


In [14]:
shape_embeds_train.shape, text_embeds_train.shape

((24, 256), (79040, 768))

In [15]:
shape_embeds_val, key_to_shape_embeds_val, text_embeds_val, key_pair_to_text_embeds_val = extract_missing_embeddings(data_loader_val, missing_keys)

  0%|          | 0/4180 [00:00<?, ?it/s]

all_nodes= 1
Checking shape embeddings...
Done!
Checking text embeddings...
Done!


In [16]:
shape_embeds_val.shape, text_embeds_val.shape

((1, 256), (4180, 768))

In [17]:
complete_hdf5(args, "train", shape_embeds_train, key_to_shape_embeds_train, text_embeds_train, key_pair_to_text_embeds_train)

Opening HDF5 file /ibex/user/slimhy/ShapeWalk/release_chained/embeddings_train__imnet.hdf5
Creating HDF5 file /ibex/user/slimhy/ShapeWalk/release_chained/embeddings_train__imnet_copy.hdf5
hmm
Done! Replaced 24 embeddings.


In [18]:
complete_hdf5(args, "val", shape_embeds_val, key_to_shape_embeds_val, text_embeds_val, key_pair_to_text_embeds_val)

Opening HDF5 file /ibex/user/slimhy/ShapeWalk/release_chained/embeddings_val__imnet.hdf5


OSError: Unable to synchronously open file (truncated file: eof = 96, sblock->base_addr = 0, stored_eof = 2048)

In [None]:
!mv /ibex/user/slimhy/ShapeWalk/release_chained/embeddings_val__imnet_copy.hdf5 /ibex/user/slimhy/ShapeWalk/release_chained/embeddings_val__imnet.hdf5
!mv /ibex/user/slimhy/ShapeWalk/release_chained/embeddings_train__imnet_copy.hdf5 /ibex/user/slimhy/ShapeWalk/release_chained/embeddings_train__imnet.hdf5