In [1]:
"""
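Compute, for every edit in the training set, the distance between its source (A) and target (B) node embeddings, so that the closest/furthest shape pairs can be identified.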
"""
import argparse
import datetime
import h5py
import json
import os
import time
from pathlib import Path

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from transformers import AutoTokenizer, CLIPTextModel, BertTokenizer, BertModel

import util.misc as misc
from engine_node2node import get_text_embeddings
from util.datasets import build_shape_surface_occupancy_dataset

In [2]:
def get_args_parser():
    """Build the command-line parser shared with the training scripts.

    The parser is created with ``add_help=False`` so it can be used as a parent
    parser or parsed from a simulated argv inside a notebook.

    Returns:
        argparse.ArgumentParser: parser with training, model, optimizer, dataset,
        checkpointing, hardware and distributed options.
    """
    parser = argparse.ArgumentParser("Latent Diffusion", add_help=False)

    # Training / run parameters
    parser.add_argument(
        "--batch_size",
        default=64,
        type=int,
        help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)",
    )
    parser.add_argument(
        "--exp_name",
        type=str,
        help="Experiment name to use",
    )
    parser.add_argument(
        "--debug_mode",
        action="store_true",
        default=False,
        help="Run in debug mode",
    )
    parser.add_argument(
        "--debug_with_forward",
        action="store_true",
        default=False,
        help="Run in debug mode, also run forward passes",
    )
    parser.add_argument(
        "--plateau_scheduler",
        action="store_true",
        default=False,
        help="Reduce LR on plateau",
    )
    # NOTE: no default; callers must guard against None before using it as a string.
    parser.add_argument(
        "--text_model_name",
        type=str,
        help="Text model name to use",
    )
    parser.add_argument(
        "--wandb_id",
        type=str,
        default=None,
        help="WandbID of the run to resume from",
    )
    parser.add_argument("--epochs", default=800, type=int)
    parser.add_argument(
        "--accum_iter",
        default=1,
        type=int,
        help="Accumulate gradient iterations (for increasing the effective batch size under memory constraints)",
    )
    parser.add_argument(
        "--valid_step",
        default=5,
        type=int,
        help="Log validation metrics every N epochs",
    )

    # Model parameters
    parser.add_argument(
        "--model",
        default="kl_d512_m512_l8_edm",
        type=str,
        metavar="MODEL",
        help="Name of model to train",
    )
    parser.add_argument(
        "--ae",
        default="kl_d512_m512_l8",
        type=str,
        metavar="MODEL",
        help="Name of autoencoder",
    )
    parser.add_argument("--ae_pth", help="Autoencoder checkpoint")
    parser.add_argument(
        "--ft_bert",
        action="store_true",
        default=False,
        help="Also fine-tune the BERT model",
    )

    # Optimizer parameters
    parser.add_argument(
        "--clip_grad",
        type=float,
        default=None,
        metavar="NORM",
        help="Clip gradient norm (default: None, no clipping)",
    )
    parser.add_argument(
        "--weight_decay", type=float, default=0.05, help="weight decay (default: 0.05)"
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=None,
        metavar="LR",
        help="learning rate (absolute lr)",
    )
    parser.add_argument(
        "--blr",
        type=float,
        default=1e-4,
        metavar="LR",
        help="base learning rate: absolute_lr = base_lr * total_batch_size / 256",
    )
    parser.add_argument(
        "--layer_decay",
        type=float,
        default=0.75,
        help="layer-wise lr decay from ELECTRA/BEiT",
    )
    parser.add_argument(
        "--min_lr",
        type=float,
        default=1e-6,
        metavar="LR",
        help="lower lr bound for cyclic schedulers that hit 0",
    )
    parser.add_argument(
        "--intensity_loss",
        action="store_true",
        default=False,
        help="Contrastive edit intensity loss using ground-truth labels.",
    )
    parser.add_argument(
        "--use_adam",
        action="store_true",
        default=False,
        help="Use Adam instead of AdamW.",
    )
    parser.add_argument(
        "--start_epoch", default=0, type=int, metavar="N", help="Start epoch"
    )
    parser.add_argument(
        "--warmup_epochs", type=int, default=40, metavar="N", help="epochs to warmup LR"
    )
    parser.add_argument("--seed", default=0, type=int)

    # Dataset parameters
    parser.add_argument(
        "--dataset",
        type=str,
        choices=["graphedits"],
        help="Dataset name",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        help="Dataset path",
    )
    parser.add_argument(
        "--data_type",
        type=str,
        help="dataset type",
    )
    parser.add_argument(
        "--max_edge_level",
        default=1,
        type=int,
        help="Maximum edge level to use",
    )
    parser.add_argument(
        "--use_embeds",
        action="store_true",
        default=False,
        help="Use precomputed embeddings",
    )
    parser.add_argument(
        "--fetch_keys",
        action="store_true",
        default=False,
        help="Fetch node keys in the dataloader",
    )

    # Checkpointing parameters
    parser.add_argument(
        "--output_dir",
        default="./output/",
        help="Path for saving weights/logs",
    )
    parser.add_argument(
        "--log_dir", default="./output/", help="Path where to tensorboard log"
    )
    parser.add_argument("--resume", default="", help="Resume from checkpoint")
    parser.add_argument(
        "--resume_weights",
        action="store_true",
        default=False,
        help="Only resume weights, not optimizer state",
    )
    parser.add_argument(
        "--resume_full_weights",
        action="store_true",
        default=False,
        help="Resume the full model weights with the EDM wrapper",
    )
    parser.add_argument("--eval", action="store_true", help="Perform evaluation only")
    parser.add_argument(
        "--dist_eval",
        action="store_true",
        default=False,
        help="Enable distributed evaluation (recommended during training for faster monitoring)",
    )

    # Hardware parameters
    parser.add_argument(
        "--device", default="cuda", help="Device to use for training / testing"
    )
    parser.add_argument("--num_workers", default=60, type=int)
    # --pin_mem / --no_pin_mem toggle the same destination; pinning is on by default.
    parser.add_argument(
        "--pin_mem",
        action="store_true",
        help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU",
    )
    parser.add_argument("--no_pin_mem", action="store_false", dest="pin_mem")
    parser.set_defaults(pin_mem=True)

    # Distributed training parameters
    parser.add_argument(
        "--world_size", default=1, type=int, help="Number of distributed processes"
    )
    parser.add_argument("--local_rank", default=-1, type=int)
    parser.add_argument("--dist_on_itp", action="store_true")
    parser.add_argument(
        "--dist_url", default="env://", help="url used to set up distributed training"
    )

    return parser

In [3]:
# Simulate a command line so the shared training parser can be reused here.
# The dataset root can be overridden with the SHAPEWALK_DATA_PATH environment
# variable instead of editing this cell.
DATA_PATH = os.environ.get("SHAPEWALK_DATA_PATH", "/ibex/user/slimhy/ShapeWalk/")

# Built as an explicit list (not str.split) so paths containing spaces survive.
cli_argv = [
    "--dataset", "graphedits",
    "--data_path", DATA_PATH,
    "--data_type", "release",
    "--batch_size", "32",
    "--num_workers", "8",
    "--device", "cuda",
    "--fetch_keys",
    "--text_model_name", "bert-base-uncased",
    "--use_embeds",
    "--seed", "0",
]

# Parse the arguments (keep the parser and the parsed namespace under distinct names)
parser = get_args_parser()
args = parser.parse_args(cli_argv)

In [4]:
# --------------------
# Runtime setup: device, seeding and data loaders.

# `--text_model_name` has no parser default; guard against None so the
# substring test does not raise TypeError.
args.use_clip = "clip" in (args.text_model_name or "")
device = torch.device(args.device)

# Fix the seed for reproducibility (offset by the process rank from misc.get_rank())
seed = args.seed + misc.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)

cudnn.benchmark = True

# The analysis loop unpacks an edit key from every batch, so keys must be fetched.
args.fetch_keys = True
dataset_train = build_shape_surface_occupancy_dataset("train", args=args)
dataset_val = build_shape_surface_occupancy_dataset("val", args=args)


def make_eval_loader(dataset, batch_size, num_workers, pin_memory=True):
    """Build a deterministic DataLoader that visits every sample exactly once.

    ``shuffle=False`` keeps the iteration order fixed. ``drop_last=False`` keeps
    the final partial batch; with ``drop_last=True`` up to ``batch_size - 1``
    samples would be silently skipped and missing from any per-sample table.

    Args:
        dataset: the dataset to iterate.
        batch_size: number of samples per batch.
        num_workers: number of DataLoader worker processes.
        pin_memory: whether to pin host memory for faster GPU transfer.

    Returns:
        torch.utils.data.DataLoader over ``dataset``.
    """
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=False,
    )


# Create data loaders (honour --pin_mem / --no_pin_mem, which defaults to True)
data_loader_train = make_eval_loader(
    dataset_train, args.batch_size, args.num_workers, pin_memory=args.pin_mem
)
data_loader_val = make_eval_loader(
    dataset_val, args.batch_size, args.num_workers, pin_memory=args.pin_mem
)
# --------------------



In [5]:
from tqdm import tqdm

# Map each edit key to the L2 distance between its node A and node B tensors.
# NOTE: despite the name, this stores the distance of *every* edit, not only the
# closest ones; sort the values afterwards to pick the closest/furthest pairs.
closest_pairs = {}

# Pure inference: no autograd graph is needed.
with torch.no_grad():
    for edit_key, nodes_a, nodes_b, _embeds_ab in tqdm(data_loader_train):
        # The 4th batch item (embeds_ab) is not used for this distance, so it is
        # no longer copied to the device.
        nodes_a = nodes_a.to(device)
        nodes_b = nodes_b.to(device)

        # Row-wise (not pairwise) L2 distance: one scalar per edit in the batch.
        # All non-batch dims are flattened so multi-dim latents also work; for
        # (B, D) inputs this equals torch.norm(nodes_a - nodes_b, dim=1).
        dists = (nodes_a - nodes_b).flatten(start_dim=1).norm(dim=1)

        # One device->host transfer per batch instead of one .item() sync per element.
        for key, dist in zip(edit_key, dists.cpu().tolist()):
            closest_pairs[key] = dist

100%|██████████| 39529/39529 [01:28<00:00, 444.65it/s]


In [7]:
# Persist the edit-key -> distance mapping as a single JSON object on disk
output_file = Path("closest_pairs_CD.json")
output_file.write_text(json.dumps(closest_pairs))