In [None]:
"""
Chained evaluation of the models.
"""
import argparse
import datetime
import h5py
import json
import os
import time
from pathlib import Path

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from transformers import AutoTokenizer, CLIPTextModel, BertTokenizer, BertModel

import util.misc as misc
from engine_node2node import get_text_embeddings
from util.datasets import build_shape_surface_occupancy_dataset


def get_args_parser():
    """Build the command-line parser for the chained evaluation.

    The parser is created with ``add_help=False``, so it does not register
    ``-h/--help`` itself. ``--ae_pth`` is the only required option.

    Note that ``--ae-latent-dim`` is spelled with dashes (unlike the other
    options) and is exposed as ``args.ae_latent_dim``.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser("Performing Chained Eval", add_help=False)

    # Model parameters
    parser.add_argument(
        "--batch_size",
        default=32,
        type=int,
        # The closing parenthesis was missing in the original help text.
        help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)",
    )
    parser.add_argument(
        "--text_model_name",
        type=str,
        help="Text model name to use",
    )
    parser.add_argument(
        "--ae",
        type=str,
        metavar="MODEL",
        help="Name of autoencoder",
    )
    parser.add_argument(
        "--ae-latent-dim",
        type=int,
        default=512 * 8,
        help="AE latent dimension",
    )
    parser.add_argument(
        "--ae_pth",
        required=True,
        help="Autoencoder checkpoint",
    )
    parser.add_argument(
        "--point_cloud_size",
        default=2048,
        type=int,
        help="input size",
    )
    parser.add_argument(
        "--fetch_keys",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--use_clip",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--use_embeds",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--intensity_loss",
        action="store_true",
        default=False,
        help="Contrastive edit intensity loss using ground-truth labels.",
    )

    # Dataset parameters
    parser.add_argument(
        "--dataset",
        type=str,
        choices=["graphedits", "graphedits_chained"],
        help="dataset name",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        help="dataset path",
    )
    parser.add_argument(
        "--data_type",
        type=str,
        help="dataset type",
    )
    parser.add_argument(
        "--max_edge_level",
        default=None,
        type=int,
        help="maximum edge level to use",
    )
    parser.add_argument(
        "--chain_length",
        default=None,
        type=int,
        help="length of chains to load",
    )

    # Runtime parameters
    parser.add_argument(
        "--device", default="cuda", help="device to use for training / testing"
    )
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--num_workers", default=60, type=int)
    parser.add_argument(
        "--pin_mem",
        action="store_true",
        help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.",
    )
    parser.add_argument(
        "--alt_ae_embeds",
        type=str,
        default=None,
        help="Alternative autoencoder embeddings to use",
    )
    parser.add_argument(
        "--ft_bert",
        action="store_true",
        default=False,
        help="Also fine-tune the BERT model",
    )

    # Edit model and checkpoint parameters
    parser.add_argument(
        "--model",
        type=str,
        metavar="MODEL",
    )
    parser.add_argument(
        "--resume",
        default="",
        help="Resume from checkpoint",
    )
    parser.add_argument(
        "--resume_full_weights",
        action="store_true",
        default=False,
        help="Resume the full model weights with the EDM wrapper",
    )

    return parser

In [None]:
import models.mlp_mapper as mlp_mapper

# Set dummy arg string to debug the parser
call_string = """--ae_pth /ibex/user/slimhy/Shape2VecSet/output/pc_ae/best_model.pt \
    --ae-latent-dim 256 \
    --text_model_name bert-base-uncased \
    --dataset graphedits_chained \
    --data_path /ibex/user/slimhy/ShapeWalk/ \
    --data_type release_chained \
    --batch_size 1 \
    --chain_length 10 \
    --num_workers 8 \
    --model mlp_mapper_bert_bneck_512_pcae \
    --resume /ibex/user/slimhy/Shape2VecSet/output/graph_edit/dm/mlp_mapper_bert_bneck_512_pcae__fine_chained/checkpoint-59.pth \
    --resume_full_weights \
    --device cuda \
    --fetch_keys \
    --use_embeds \
    --alt_ae_embeds pc_ae \
    --seed 0"""

# Parse the arguments (the parser and the parsed namespace get distinct names).
parser = get_args_parser()
args = parser.parse_args(call_string.split())
# The text model name decides whether the CLIP code path is used.
args.use_clip = "clip" in args.text_model_name
device = torch.device(args.device)

# Look the model constructor up by name in the mlp_mapper module.
model = mlp_mapper.__dict__[args.model](use_linear_proj=not args.use_clip)
model.to(device)

# Load the checkpoint
if args.resume:
    print("Loading checkpoint [%s]..." % args.resume)
    # Load on CPU first; load_state_dict copies into the on-device parameters.
    checkpoint = torch.load(args.resume, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    print("Done.")

# This notebook only evaluates: put the model in eval mode so that any
# dropout / batch-norm layers behave deterministically at inference time.
model.eval()

In [None]:
# Build the train and validation splits (in that order) from the same arguments.
dataset_train, dataset_val = (
    build_shape_surface_occupancy_dataset(split, args=args)
    for split in ("train", "val")
)

In [None]:
# --------------------
# Let cuDNN benchmark its kernels.
cudnn.benchmark = True

# Always fetch the keys, whatever was passed on the command line.
args.fetch_keys = True

# Reproducibility: the base seed is offset by misc.get_rank() and fed to
# both the PyTorch and the NumPy random number generators.
seed = misc.get_rank() + args.seed
torch.manual_seed(seed)
np.random.seed(seed)
# --------------------

In [None]:
import json
import torch
import numpy as np
import os.path as osp
import pprint
import warnings

from argparse import ArgumentParser
from models.mlp import MLP
from models.point_net import PointNet
from models.pointcloud_autoencoder import PointcloudAutoencoder


def describe_pc_ae(args):
    """Build an (untrained) point-cloud autoencoder from a saved configuration.

    Args:
        args: namespace providing ``encoder_net``, ``encoder_conv_layers``,
            ``decoder_net``, ``decoder_fc_neurons`` and ``n_pc_points``.

    Returns:
        A ``PointcloudAutoencoder`` wrapping a ``PointNet`` encoder and an
        ``MLP`` decoder.

    Raises:
        NotImplementedError: if ``encoder_net`` is not ``"pointnet"`` or
            ``decoder_net`` is not ``"mlp"``.
    """
    # Validate both architecture names up front. Previously an unsupported
    # decoder left ``ae_decoder`` unbound and failed with UnboundLocalError.
    if args.encoder_net != "pointnet":
        raise NotImplementedError(f"Unsupported encoder_net: {args.encoder_net!r}")
    if args.decoder_net != "mlp":
        raise NotImplementedError(f"Unsupported decoder_net: {args.decoder_net!r}")

    # Make an AE.
    ae_encoder = PointNet(init_feat_dim=3, conv_dims=args.encoder_conv_layers)
    # The decoder input width is the encoder's last convolution width.
    encoder_latent_dim = args.encoder_conv_layers[-1]

    ae_decoder = MLP(
        in_feat_dims=encoder_latent_dim,
        # Final layer emits 3 values per output point.
        out_channels=args.decoder_fc_neurons + [args.n_pc_points * 3],
        b_norm=False,
    )

    return PointcloudAutoencoder(ae_encoder, ae_decoder)


def load_state_dicts(checkpoint_file, map_location=None, **kwargs):
    """Load torch items from saved state dictionaries.

    Args:
        checkpoint_file: path or file-like object accepted by ``torch.load``.
        map_location: forwarded to ``torch.load``; ``None`` (the default)
            keeps ``torch.load``'s own default behaviour.
        **kwargs: ``checkpoint_key=obj`` pairs. For each pair,
            ``obj.load_state_dict(checkpoint[checkpoint_key])`` is called.

    Returns:
        The value stored under ``"epoch"`` in the checkpoint, or ``None`` when
        the checkpoint has no such key. An epoch of ``0`` is returned as ``0``
        (it used to be swallowed by a truthiness check).

    Raises:
        KeyError: if a requested key is missing from the checkpoint.
    """
    # ``map_location=None`` is torch.load's default, so no branching is needed.
    checkpoint = torch.load(checkpoint_file, map_location=map_location)

    for key, target in kwargs.items():
        target.load_state_dict(checkpoint[key])

    return checkpoint.get("epoch")


def read_saved_args(config_file, override_or_add_args=None, verbose=False):
    """
    Rebuild an argument namespace from a JSON configuration file.

    :param config_file: json file containing arguments
    :param override_or_add_args: optional dict applied on top of the file's
        content, e.g., {'gpu': '0'} sets the resulting args.gpu to '0'
    :param verbose: when True, pretty-print the resulting arguments
    :return: the namespace holding the loaded (and overridden) arguments
    """
    # Start from an empty namespace and swap in the saved dictionary wholesale.
    saved = ArgumentParser().parse_args([])
    with open(config_file, "r") as handle:
        saved.__dict__ = json.load(handle)

    if override_or_add_args is not None:
        for name, value in override_or_add_args.items():
            setattr(saved, name, value)

    if verbose:
        print(pprint.pformat(vars(saved)))

    return saved


def load_pretrained_pc_ae(model_file):
    """Rebuild a point-cloud autoencoder and load its saved weights.

    The architecture is described by ``config.json.txt``, which is read from
    the directory containing ``model_file``.

    Args:
        model_file: path to the checkpoint to load.

    Returns:
        ``(pc_ae, pc_ae_args)``: the model with weights loaded (on CPU) and
        the namespace read from the configuration file.
    """
    config_file = osp.join(osp.dirname(model_file), "config.json.txt")
    pc_ae_args = read_saved_args(config_file)
    pc_ae = describe_pc_ae(pc_ae_args)

    # Normalise both sides before comparing: a relative ``log_dir``, trailing
    # slash or ``..`` component would otherwise trigger a spurious warning.
    expected_file = osp.abspath(osp.join(pc_ae_args.log_dir, "best_model.pt"))
    if expected_file != osp.abspath(model_file):
        warnings.warn(
            "The saved best_model.pt in the corresponding log_dir is not equal to the one requested."
        )

    # The model was just built on CPU, so map the checkpoint there as well;
    # this also works when the checkpoint was saved from another device.
    best_epoch = load_state_dicts(model_file, map_location="cpu", model=pc_ae)
    print(f"Pretrained PC-AE is loaded at epoch {best_epoch}.")
    return pc_ae, pc_ae_args


# Restore the pretrained autoencoder, move it to the target device and
# switch it to eval mode.
print("Loading autoencoder [%s]..." % args.ae_pth)
pc_ae, pc_ae_args = load_pretrained_pc_ae(args.ae_pth)
pc_ae = pc_ae.to(device).eval()
print("Done.")

In [None]:
from plot_pc import plot_pointclouds

def apply_edit(net, x_a, embed_ab):
    """Return ``x_a`` shifted by the edit vector predicted by ``net``.

    ``net`` is called as ``net(x_a, embed_ab)`` and its output is added to
    ``x_a`` (residual update).
    """
    return x_a + net(x_a, embed_ab)

def apply_iterated_edits(model, embeds_a, embeds_b, embeds_text):
    """Apply one text-conditioned edit and decode source, edited and target.

    Args:
        model: edit network, invoked through ``apply_edit``.
        embeds_a: source embeddings (batch first), fed to ``pc_ae.decoder``.
        embeds_b: ground-truth target embeddings with the same batch size.
        embeds_text: text embeddings conditioning the edit.

    Returns:
        ``((orig, rec, rec_gt), (x_a, x_b_edited, x_b))``. The first tuple
        holds the decoded point clouds, each reshaped to
        ``(batch, n_points, 3)``; the second holds the corresponding
        embeddings on ``device``.
    """
    # Use the configured device rather than a bare ``.cuda()`` so the inputs
    # always land where ``model`` and ``pc_ae`` were placed.
    x_a = embeds_a.to(device)
    x_b = embeds_b.to(device)
    embeds_text = embeds_text.to(device)

    x_b_edited = apply_edit(model, x_a, embeds_text)

    # Decode the batch. ``-1`` lets the decoder's output width determine the
    # number of points instead of hard-coding 4096.
    b_size = x_b.shape[0]
    with torch.inference_mode():
        orig = pc_ae.decoder(x_a).reshape([b_size, -1, 3])
        rec = pc_ae.decoder(x_b_edited).reshape([b_size, -1, 3])
        rec_gt = pc_ae.decoder(x_b).reshape([b_size, -1, 3])

    return (orig, rec, rec_gt), (x_a, x_b_edited, x_b)

def apply_iterated_edits__SAFE(model, embeds_a, embeds_text):
    """Apply one text-conditioned edit without needing the ground-truth target.

    Args:
        model: edit network, invoked through ``apply_edit``.
        embeds_a: source embeddings (batch first), fed to ``pc_ae.decoder``.
        embeds_text: text embeddings conditioning the edit.

    Returns:
        ``((orig, rec), (x_a, x_b_edited))``. The first tuple holds the
        decoded source and edited point clouds, each reshaped to
        ``(batch, n_points, 3)``; the second holds the corresponding
        embeddings on ``device``.
    """
    # Use the configured device rather than a bare ``.cuda()`` so the inputs
    # always land where ``model`` and ``pc_ae`` were placed.
    x_a = embeds_a.to(device)
    embeds_text = embeds_text.to(device)

    x_b_edited = apply_edit(model, x_a, embeds_text)

    # Decode the batch. ``-1`` lets the decoder's output width determine the
    # number of points instead of hard-coding 4096.
    b_size = x_a.shape[0]
    with torch.inference_mode():
        orig = pc_ae.decoder(x_a).reshape([b_size, -1, 3])
        rec = pc_ae.decoder(x_b_edited).reshape([b_size, -1, 3])

    return (orig, rec), (x_a, x_b_edited)

In [None]:
import json

# Metadata of the chained validation split; entries are looked up by edit key.
json_f = "/ibex/user/slimhy/ShapeWalk/release_chained/release_chained_val.json"
# Use a context manager so the file handle is closed once parsing is done.
with open(json_f) as json_fp:
    json_d = json.load(json_fp)

def get_prompt(edit_key):
    """Look up the ``"prompt"`` field stored for ``edit_key`` in ``json_d``."""
    entry = json_d[edit_key]
    return entry["prompt"]

In [None]:
# Loading HDF5 data
def decode_json_dset(dset):
    """Parse the JSON document stored as UTF-8 bytes in the first entry of ``dset``."""
    raw_bytes = dset[:][0]
    return json.loads(raw_bytes.decode("utf-8"))

hdf5_file = os.path.join(
    args.data_path, args.data_type, "embeddings_val__pc_ae.hdf5"
)

# Load everything in RAM. The context manager releases the HDF5 handle as
# soon as the data has been copied out, instead of leaving it open for the
# lifetime of the kernel.
with h5py.File(hdf5_file, "r") as hdf5_f:
    shape_embeds = (
        torch.tensor(hdf5_f["shape_embeds"][:]).to("cpu").type(torch.float32)
    )
    key_to_shape_embeds = decode_json_dset(hdf5_f["key_to_shape_embeds"])

# Device-side copy used for the nearest-neighbour search; it follows the
# configured device rather than a bare ``.cuda()``.
D_shape_embeds = shape_embeds.to(device)
# Inverse lookup: from the stored value back to its key.
shape_embeds_to_key = {v: k for k, v in key_to_shape_embeds.items()}


def query_embeds(query_embed, ignore_embed):
    """
    Returns the index of the closest embedding in the dataset,
    its key, and the embedding itself.
    Ignore the given embedding.

    :param query_embed: 1-D embedding to search for, on the same device as
        ``D_shape_embeds``
    :param ignore_embed: 1-D embedding whose exact copies in the dataset are
        excluded from the search
    :return: ``(idx, key, embed)`` with ``idx`` a 0-dim tensor, ``key`` the
        entry of ``shape_embeds_to_key`` for that row and ``embed`` the
        matching row of ``shape_embeds``
    """
    # A row is ignored only when *every* coordinate equals the ignored
    # embedding. The previous element-wise ``torch.where`` also excluded rows
    # that merely shared a single coordinate value with it.
    ignore_mask = (D_shape_embeds == ignore_embed).all(dim=1)

    # Compute the L2 distance between the query and all embeddings,
    # except the one to ignore
    dists = torch.norm(D_shape_embeds - query_embed, dim=1)
    dists[ignore_mask] = float("inf")

    # Get the index of the closest embedding
    idx = torch.argmin(dists)
    # Plain int for the dictionary lookup and for indexing the CPU tensor.
    row = int(idx)

    # Return the index, the key and the embedding
    return idx, shape_embeds_to_key[row], shape_embeds[row]

In [None]:
# Chained evaluation: starting from one validation sample, repeatedly apply
# the edit described by the next sample's text embedding, snap the result to
# the nearest stored shape embedding, and continue from there.
with torch.no_grad():
    total_iter = 0  # number of chained edits applied so far
    all_orig = []  # one [retrieved key, dataset edit keys] pair per step
    skip_n = 0  # number of dataset samples seen so far
    to_skip = 10  # the first `to_skip` samples are skipped
    k_filter = 0  # not used anywhere in this cell
    for _, edit_keys, node_a, _, text_embeds in dataset_val:
        # Skip the first `to_skip` samples.
        skip_n += 1
        if skip_n <= to_skip: continue
        
        # Only the first processed sample provides the starting node; later
        # steps continue from the embedding retrieved in the previous step.
        if total_iter == 0:
            prev_node = node_a

        # Tile both inputs 8 times along a leading batch dimension; only row 0
        # of the result is used below.
        # NOTE(review): presumably the model needs a batch larger than 1 — confirm.
        prev_node = prev_node.repeat(8,1)
        text_embeds = text_embeds.repeat(8,1)

        (orig, _), (x_a, x_b_edited) = apply_iterated_edits__SAFE(model, prev_node, text_embeds)
        # Replace the predicted embedding by its nearest stored embedding,
        # excluding the embedding we started this step from.
        target_idx, target_key, x_b_edited = query_embeds(x_b_edited[0], x_a[0])

        all_orig += [[target_key, edit_keys]]

        # The retrieved embedding becomes the source of the next edit.
        prev_node = x_b_edited
        total_iter += 1
        # Stop after 10 chained edits.
        if total_iter == 10:
            break

In [None]:
# Display the [retrieved shape key, dataset edit keys] pair collected at each chained step.
all_orig