In [1]:
%cd /ibex/user/slimhy/Shape2VecSet/code

/ibex/user/slimhy/Shape2VecSet/code


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [2]:
"""
Some experiments with embeddings proximity when sampling pointclouds.
"""
import argparse
import h5py
import json
import os

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from tqdm.notebook import tqdm
from transformers import AutoTokenizer, CLIPTextModel, BertTokenizer, BertModel

import util.misc as misc
import models.autoencoders as ae_mods
from engine_node2node import get_text_embeddings
from util.datasets import build_shape_surface_occupancy_dataset

In [3]:
def get_args_parser():
    parser = argparse.ArgumentParser("Extracting Features", add_help=False)

    # Model parameters
    parser.add_argument(
        "--batch_size",
        default=32,
        type=int,
        help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus",
    )
    parser.add_argument(
        "--text_model_name",
        type=str,
        help="Text model name to use",
    )
    parser.add_argument(
        "--ae",
        type=str,
        metavar="MODEL",
        help="Name of autoencoder",
    )
    parser.add_argument(
        "--ae-latent-dim",
        type=int,
        default=512*8,
        help="AE latent dimension",
    )
    parser.add_argument(
        "--ae_pth",
        required=True,
        help="Autoencoder checkpoint"
    )
    parser.add_argument(
        "--point_cloud_size",
        default=2048,
        type=int,
        help="input size"
    )
    parser.add_argument(
        "--fetch_keys",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--use_embeds",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--intensity_loss",
        action="store_true",
        default=False,
        help="Contrastive edit intensity loss using ground-truth labels.",
    )

    # Dataset parameters
    parser.add_argument(
        "--dataset",
        type=str,
        choices=["graphedits"],
        help="dataset name",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        help="dataset path",
    )
    parser.add_argument(
        "--data_type",
        type=str,
        help="dataset type",
    )
    parser.add_argument(
        "--max_edge_level",
        default=None,
        type=int,
        help="maximum edge level to use",
    )
    parser.add_argument(
        "--device", default="cuda", help="device to use for training / testing"
    )
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--num_workers", default=60, type=int)
    parser.add_argument(
        "--pin_mem",
        action="store_true",
        help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.",
    )

    return parser

In [4]:
# Set dummy arg string to debug the parser
call_string = """--ae_pth /ibex/user/slimhy/Shape2VecSet/output/graph_edit/ae/ae_m512.pth \
    --ae kl_d512_m512_l8 \
    --ae-latent-dim 4096 \
    --text_model_name bert-base-uncased \
    --dataset graphedits \
    --data_path /ibex/user/slimhy/ShapeWalk_RND/ \
    --data_type release \
    --batch_size 32 \
    --num_workers 8 \
    --device cuda \
    --fetch_keys \
    --seed 0"""

# Parse the arguments
args = get_args_parser()
args = args.parse_args(call_string.split())

In [5]:
# --------------------
args.use_clip = "clip" in args.text_model_name
device = torch.device(args.device)

# Fix the seed for reproducibility
seed = args.seed + misc.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)

cudnn.benchmark = True

args.fetch_keys = True
dataset_train = build_shape_surface_occupancy_dataset("train", args=args)
dataset_val = build_shape_surface_occupancy_dataset("val", args=args)

# Create data loaders
data_loader_train = torch.utils.data.DataLoader(
    dataset_train,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.num_workers,
    pin_memory=True,
    drop_last=True,
)

data_loader_val = torch.utils.data.DataLoader(
    dataset_val,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.num_workers,
    pin_memory=True,
    drop_last=True,
)
# --------------------

# Instantiate autoencoder
ae = ae_mods.__dict__[args.ae]()
ae.eval()
print("Loading autoencoder %s" % args.ae_pth)
ae.load_state_dict(torch.load(args.ae_pth, map_location="cpu")["model"])
ae.to(device)

# Initialize text CLIP model
if args.use_clip:
    # Instantiate tokenizer + CLIP model
    tokenizer = AutoTokenizer.from_pretrained(args.text_model_name)
    text_model = CLIPTextModel.from_pretrained(args.text_model_name).to(device)
else:
    # Instantiate BERT model and create linear projection layer
    tokenizer = BertTokenizer.from_pretrained(args.text_model_name)
    text_model = BertModel.from_pretrained(args.text_model_name).to(device)



Loading autoencoder /ibex/user/slimhy/Shape2VecSet/output/graph_edit/ae/ae_m512.pth


In [6]:
def get_split_embeddings(args, data_loader):
    text_latent_dim = 512 if args.use_clip else 768

    # Create stacked numpy arrays to store embeddings
    B = args.batch_size
    n_batches = len(data_loader)
    n_entries = n_batches * B
    embeds_xa   = np.zeros((n_entries, args.ae_latent_dim))
    embeds_xb   = np.zeros((n_entries, args.ae_latent_dim))
    embeds_text = np.zeros((n_entries, text_latent_dim))
    all_keys = ["_" for _ in range(n_entries)]

    # Iterate over the dataset and extract text embeddings
    for k, (edit_keys, nodes_a, nodes_b, prompts_ab) in enumerate(tqdm(data_loader)):
        nodes_a = nodes_a.to(device, non_blocking=True)
        nodes_b = nodes_b.to(device, non_blocking=True)
        embeds_ab = get_text_embeddings(text_model=text_model,
                                        tokenizer=tokenizer,
                                        texts=prompts_ab,
                                        device=device)

        with torch.cuda.amp.autocast(enabled=False):
            with torch.no_grad():
                _, x_a = ae.encode(nodes_a)
                _, x_b = ae.encode(nodes_b)

        # Move batch to CPU and convert to numpy
        embeds_ab = embeds_ab.cpu().numpy()
        x_a = x_a.cpu().numpy()
        x_b = x_b.cpu().numpy()

        # Store to stacked arrays
        embeds_xa[k*B:(k+1)*B] = x_a.reshape(B, -1)
        embeds_xb[k*B:(k+1)*B] = x_b.reshape(B, -1)
        embeds_text[k*B:(k+1)*B] = embeds_ab.reshape(B, -1)
        all_keys[k*B:(k+1)*B] = edit_keys

        #if k==2: break
    
    return all_keys, embeds_xa, embeds_xb, embeds_text


def extract_embeddings(split, data_loader):
    """
    Extract embeddings and remap to shape/edit keys.
    """
    # Extract embeddings
    all_keys, embeds_xa, embeds_xb, embeds_text = get_split_embeddings(args, data_loader)

    edit_keys_sp = [k.split('_') for k in all_keys]
    keys_node_a = [k[0] for k in edit_keys_sp]
    keys_node_b = [k[1] for k in edit_keys_sp]

    # Map node keys to indices
    node_a_to_idx = {k: i for i, k in enumerate(keys_node_a)}
    node_b_to_idx = {k: i for i, k in enumerate(keys_node_b)}
    all_nodes = list(set(keys_node_a + keys_node_b))
    print("all_nodes=", len(all_nodes))

    # Build a matrix with all the embeddings
    # using the indices
    shape_embeds = np.zeros((len(all_nodes), args.ae_latent_dim))
    k = 0
    key_to_shape_embeds = {}
    for node_key in node_a_to_idx:
        idx = node_a_to_idx[node_key]
        shape_embeds[k] = embeds_xa[idx]
        key_to_shape_embeds[node_key] = k
        k += 1

    # Remove all nodes already added from node_b
    for node_key in (node_b_to_idx.keys() - node_a_to_idx.keys()):
        idx = node_b_to_idx[node_key]
        shape_embeds[k] = embeds_xb[idx]
        key_to_shape_embeds[node_key] = k
        k += 1

    # Double check that everything is correct
    # Iterate on edit_keys
    print("Checking shape embeddings...")
    intersec_nodes = node_a_to_idx.keys() & node_b_to_idx.keys()
    for node_a, node_b in edit_keys_sp:
        assert node_a in key_to_shape_embeds
        assert node_b in key_to_shape_embeds
    
        # Check that embeddings are correct
        idx_a = key_to_shape_embeds[node_a]
        idx_b = key_to_shape_embeds[node_b]
       
        if node_a not in intersec_nodes:
            assert np.allclose(shape_embeds[idx_a], embeds_xa[node_a_to_idx[node_a]])
        if node_b not in intersec_nodes:
            assert np.allclose(shape_embeds[idx_b], embeds_xb[node_b_to_idx[node_b]])
    print("Done!")

    key_pair_to_text_embeds = {key_pair : k for k, key_pair in enumerate(all_keys)}

    # Double check that everything is correct
    # Iterate on edit_keys
    print("Checking text embeddings...")
    for key_pair in all_keys:
        assert key_pair in key_pair_to_text_embeds

    print("Done!")

    return shape_embeds, key_to_shape_embeds, embeds_text, key_pair_to_text_embeds

In [7]:
def create_hdf5(args, split, shape_embeds, key_to_shape_embeds, embeds_text, key_pair_to_text_embeds):
    """
    Create HDF5 file with the embeddings
    """
    # Create HDF5 file
    hdf5_path = os.path.join(args.data_path, args.data_type, "embeddings_%s.hdf5" % split)
    # If exists: delete
    if os.path.exists(hdf5_path):
        os.remove(hdf5_path)
        print("Deleted existing HDF5 file %s" % hdf5_path)
    print("Creating HDF5 file %s" % hdf5_path)
    f = h5py.File(hdf5_path, "w")

    # Create datasets
    f.create_dataset("shape_embeds", data=shape_embeds)
    f.create_dataset("text_embeds", data=embeds_text)
    f.create_dataset("key_to_shape_embeds",
                     data=json.dumps(key_to_shape_embeds),
                     shape=(1,),
                     dtype=h5py.string_dtype(encoding="utf-8"))
    f.create_dataset("key_pair_to_text_embeds",
                     data=json.dumps(key_pair_to_text_embeds),
                     shape=(1,),
                     dtype=h5py.string_dtype(encoding="utf-8"))
    f.close()

    print("Done!")

In [8]:
text_latent_dim = 512 if args.use_clip else 768

# Create stacked numpy arrays to store embeddings
B = args.batch_size
n_batches = len(data_loader_val)
n_entries = n_batches * B
embeds_xa   = np.zeros((n_entries, args.ae_latent_dim))
embeds_xb   = np.zeros((n_entries, args.ae_latent_dim))
embeds_text = np.zeros((n_entries, text_latent_dim))
all_keys = ["_" for _ in range(n_entries)]

# Iterate over the dataset and extract text embeddings
for k, (edit_keys, nodes_a, nodes_b, prompts_ab) in enumerate(tqdm(data_loader_val)):
    nodes_a = nodes_a.to(device, non_blocking=True)
    nodes_b = nodes_b.to(device, non_blocking=True)

    with torch.cuda.amp.autocast(enabled=False):
        with torch.no_grad():
            _, x_a = ae.encode(nodes_a)
            _, x_b = ae.encode(nodes_b)

    break

  0%|          | 0/93 [00:00<?, ?it/s]

In [9]:
import numpy as np
from scipy.optimize import linear_sum_assignment

def reorder_pointclouds(nodes_a, nodes_b):
    """
    Reorder pointclouds to minimize the distance between
    corresponding points.
    """
    # Compute pairwise distances
    dists = torch.cdist(nodes_a, nodes_b, p=2)

    # Compute optimal assignment
    row_ind, col_ind = linear_sum_assignment(dists.cpu().numpy())

    #col_ind = np.random.permutation(len(nodes_a))

    # Reorder nodes_b
    nodes_b = nodes_b.clone()
    nodes_b = nodes_b[col_ind]

    return nodes_a, nodes_b


In [13]:
def pwise_dist(pc_a, pc_b):
    dist_total = 0.0
    for i in range(pc_a.shape[0]):
        dist = (pc_a[i] - pc_b[i])**2
        dist_total += dist.mean()
    return dist_total/pc_a.shape[0]

def embed_dist(pc_a, pc_b):
    _, x_a = ae.encode(nodes_a)
    _, x_b = ae.encode(nodes_b)
    return (torch.abs(x_b - x_a)).mean()

In [14]:
# Compute pairwise distances
p_a, p_b = nodes_a[0].clone(), nodes_b[0].clone()
pwise_dist(p_a, p_b), embed_dist(p_a, p_b)

(tensor(0.2125, device='cuda:0'),
 tensor(0.9245, device='cuda:0', grad_fn=<MeanBackward0>))

In [15]:
# Compute pairwise distances
t_a, t_b = reorder_pointclouds(p_a, p_b)
pwise_dist(t_a, t_b), embed_dist(t_a, t_b)

(tensor(0.0013, device='cuda:0'),
 tensor(0.9240, device='cuda:0', grad_fn=<MeanBackward0>))