In [1]:
"""
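Find the furthest pairs of shapes in the dataset using the Chamfer distance (CD).

NOTE(review): the distance cell in this notebook currently takes an L2 norm of
the difference between the two node tensors rather than a Chamfer distance --
confirm which metric is intended.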
"""
import argparse
import datetime
import h5py
import json
import os
import time
from pathlib import Path

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from transformers import AutoTokenizer, CLIPTextModel, BertTokenizer, BertModel

import util.misc as misc
from engine_node2node import get_text_embeddings
from util.datasets import build_shape_surface_occupancy_dataset

In [2]:
def get_args_parser():
    """Build the command-line parser used by this notebook.

    The parser is created with ``add_help=False``, so it does not register
    ``-h/--help`` itself. ``--ae_pth`` is the only required option. Note that
    ``--ae-latent-dim`` is spelled with hyphens and is therefore exposed as
    ``args.ae_latent_dim`` on the parsed namespace.

    Returns:
        argparse.ArgumentParser: the configured (not yet parsed) parser.
    """
    parser = argparse.ArgumentParser("Latent Diffusion", add_help=False)

    # Model parameters
    parser.add_argument(
        "--batch_size",
        default=32,
        type=int,
        # Closing parenthesis was missing in the original help text.
        help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)",
    )
    parser.add_argument(
        "--text_model_name",
        type=str,
        help="Text model name to use",
    )
    parser.add_argument(
        "--ae",
        type=str,
        metavar="MODEL",
        help="Name of autoencoder",
    )
    parser.add_argument(
        "--ae-latent-dim",
        type=int,
        default=512 * 8,
        help="AE latent dimension",
    )
    parser.add_argument(
        "--ae_pth",
        required=True,
        help="Autoencoder checkpoint",
    )
    parser.add_argument(
        "--point_cloud_size",
        default=2048,
        type=int,
        help="input size",
    )
    # Boolean switches: ``store_true`` already defaults to False, so no
    # explicit default is needed.
    parser.add_argument("--fetch_keys", action="store_true")
    parser.add_argument("--use_embeds", action="store_true")
    parser.add_argument("--ft_bert", action="store_true")
    parser.add_argument(
        "--intensity_loss",
        action="store_true",
        help="Contrastive edit intensity loss using ground-truth labels.",
    )

    # Dataset parameters
    parser.add_argument(
        "--dataset",
        type=str,
        choices=["graphedits"],
        help="dataset name",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        help="dataset path",
    )
    parser.add_argument(
        "--data_type",
        type=str,
        help="dataset type",
    )
    parser.add_argument(
        "--max_edge_level",
        default=None,
        type=int,
        help="maximum edge level to use",
    )

    # Runtime parameters
    parser.add_argument(
        "--device", default="cuda", help="device to use for training / testing"
    )
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--num_workers", default=60, type=int)
    parser.add_argument(
        "--pin_mem",
        action="store_true",
        help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.",
    )

    return parser

In [3]:
# Stand-in command line so the parser can be exercised from inside the notebook.
# The trailing backslashes join the lines, so this is one whitespace-separated string.
call_string = """--ae_pth /ibex/user/slimhy/Shape2VecSet/output/graph_edit/ae/ae_m512.pth \
    --ae kl_d512_m512_l8 \
    --ae-latent-dim 4096 \
    --text_model_name bert-base-uncased \
    --dataset graphedits \
    --data_path /ibex/user/slimhy/ShapeWalk/ \
    --data_type release \
    --batch_size 32 \
    --num_workers 8 \
    --device cuda \
    --fetch_keys \
    --seed 0"""

# Build the parser and parse the tokens in a single expression, so `args`
# is only ever bound to the parsed namespace.
args = get_args_parser().parse_args(call_string.split())

In [4]:
# --------------------
# True when the configured text model name contains the substring "clip".
args.use_clip = "clip" in args.text_model_name
device = torch.device(args.device)

# Fix the seed for reproducibility (offset by the value of misc.get_rank()).
seed = args.seed + misc.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)

cudnn.benchmark = True

# Force key fetching regardless of the CLI flag: the distance loop in this
# notebook unpacks an edit key from every batch.
args.fetch_keys = True
dataset_train = build_shape_surface_occupancy_dataset("train", args=args)
dataset_val = build_shape_surface_occupancy_dataset("val", args=args)

# Shared loader settings. Order is kept deterministic (shuffle=False) and the
# final, possibly smaller, batch is kept (drop_last=False) so that every sample
# of the split is visited when scanning the dataset. Memory is pinned
# unconditionally; the --pin_mem flag is not consulted here.
loader_kwargs = dict(
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.num_workers,
    pin_memory=True,
    drop_last=False,
)

# Create data loaders
data_loader_train = torch.utils.data.DataLoader(dataset_train, **loader_kwargs)
data_loader_val = torch.utils.data.DataLoader(dataset_val, **loader_kwargs)
# --------------------




In [None]:
from tqdm import tqdm
import torch
import chamfer3D.dist_chamfer_3D
from chamfer3D.fscore import fscore

# Maps each edit key to the distance between its two node tensors.
# NOTE(review): despite the name, this holds one scalar distance per key,
# not a set of pairs.
closest_pairs = {}

with torch.no_grad():
    # Iterate over train set
    for edit_key, nodes_a, nodes_b, embeds_ab in tqdm(data_loader_train):
        # Only the node tensors are needed for the distance; embeds_ab is
        # unused in this cell, so it is not copied to the device.
        nodes_a = nodes_a.to(device)
        nodes_b = nodes_b.to(device)

        # L2 norm of the difference, reduced over dim 1.
        # NOTE(review): this is an L2 distance between the two tensors, not a
        # Chamfer distance, even though chamfer3D is imported above -- confirm
        # the intended metric. Assumes nodes_* are (batch, dim) so the result
        # is one scalar per sample -- TODO confirm.
        dist_mat = torch.norm(nodes_a - nodes_b, dim=1)
        if dist_mat.dim() != 1 or dist_mat.shape[0] != len(edit_key):
            raise ValueError(
                "expected one distance per edit key, got shape "
                f"{tuple(dist_mat.shape)} for {len(edit_key)} keys"
            )

        # One device-to-host transfer per batch instead of one .item() call
        # (and one synchronisation) per sample.
        closest_pairs.update(zip(edit_key, dist_mat.tolist()))