In [1]:
"""
Chained evaluation of the models.
"""
%cd ..

import argparse
import numpy as np
import torch
import torch.backends.cudnn as cudnn

import util.misc as misc
from util.misc import MetricLogger
from util.datasets import build_shape_surface_occupancy_dataset

import models.mlp_mapper as mlp_mapper

from losses.chamfer import chamfer_loss

# Silence all Python warnings for the rest of the kernel session, not only
# torch's: filterwarnings("ignore") with no category/module filter matches
# every warning. NOTE(review): consider narrowing with category=/module= so
# unrelated problems still surface.
import warnings

warnings.filterwarnings("ignore")


def get_args_parser():
    """
    Build the argument parser for the chained evaluation.

    The parser is created with ``add_help=False``, so it defines no ``-h`` flag.

    Returns:
        argparse.ArgumentParser: parser covering model, dataset and runtime options.
    """
    parser = argparse.ArgumentParser("Performing Chained Eval", add_help=False)

    # Model parameters
    parser.add_argument(
        "--batch_size",
        default=32,
        type=int,
        # Adjacent literals are concatenated, so the separating space and the
        # closing parenthesis must be written explicitly.
        help="Batch size per GPU "
        "(effective batch size is batch_size * accum_iter * # gpus)",
    )
    parser.add_argument(
        "--text_model_name",
        type=str,
        help="Text model name to use",
    )
    parser.add_argument(
        "--ae",
        type=str,
        metavar="MODEL",
        help="Name of autoencoder",
    )
    parser.add_argument(
        "--ae-latent-dim",
        type=int,
        default=512 * 8,
        help="AE latent dimension",
    )
    parser.add_argument("--ae_pth", help="Autoencoder checkpoint")
    parser.add_argument("--point_cloud_size", default=2048, type=int, help="input size")
    parser.add_argument(
        "--fetch_keys",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--use_clip",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--use_embeds",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--intensity_loss",
        action="store_true",
        default=False,
        help="Contrastive edit intensity loss using ground-truth labels.",
    )

    # Dataset parameters
    parser.add_argument(
        "--dataset",
        type=str,
        choices=["graphedits", "graphedits_chained"],
        help="dataset name",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        help="dataset path",
    )
    parser.add_argument(
        "--data_type",
        type=str,
        help="dataset type",
    )
    parser.add_argument(
        "--max_edge_level",
        default=None,
        type=int,
        help="maximum edge level to use",
    )
    parser.add_argument(
        "--chain_length",
        default=None,
        type=int,
        help="length of chains to load",
    )
    parser.add_argument(
        "--device", default="cuda", help="device to use for training / testing"
    )
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--num_workers", default=60, type=int)
    parser.add_argument(
        "--pin_mem",
        action="store_true",
        help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.",
    )
    parser.add_argument(
        "--alt_ae_embeds",
        type=str,
        default=None,
        help="Alternative autoencoder embeddings to use",
    )
    parser.add_argument(
        "--ft_bert",
        action="store_true",
        default=False,
        help="Also fine-tune the BERT model",
    )
    parser.add_argument(
        "--model",
        type=str,
        metavar="MODEL",
        help="Name of the latent mapper constructor to instantiate",
    )
    parser.add_argument("--resume", default="", help="Resume from checkpoint")
    parser.add_argument(
        "--resume_full_weights",
        action="store_true",
        default=False,
        help="Resume the full model weights with the EDM wrapper",
    )

    return parser

/ibex/user/slimhy/Shape2VecSet/code
Jitting Chamfer 3D
Loaded JIT 3D CUDA chamfer distance


In [2]:
# LaTeX snippets for the per-chain-length results table (|P| = 10, 15, 20).
# Raw strings keep every LaTeX backslash literal. The previous non-raw
# literals relied on invalid escape sequences such as "\m", "\c", "\e" and
# "\%", which Python reports as warnings and plans to reject.
table_header = r"""
\begin{tabular}{@{}lcccccccccccccccccccccccccc@{}}
    \toprule
    \multicolumn{1}{c}{\multirow{2}{*}{\textbf{Model}}} & \multicolumn{1}{c}{\multirow{2}{*}{decoupled?}}
    &  & \multicolumn{6}{c}{$|\mathcal{P}|$ = 10}
    &  & \multicolumn{6}{c}{$|\mathcal{P}|$ = 15}
    &  & \multicolumn{6}{c}{$|\mathcal{P}|$ = 20} \\ \cmidrule(l){4-23}
    \multicolumn{1}{c}{}  & \multicolumn{1}{c}{}
    &
    & $\textsc{F}_{\textsc{CD} - \textsc{Rec}}$
    & \multicolumn{1}{c}{$\textsc{A}_{\textsc{CD} - \textsc{Rec}}$}
    & $\textsc{F}_{\textsc{CD} - \textsc{Real}}$
    & \multicolumn{1}{c}{$\textsc{A}_{\textsc{CD} - \textsc{Real}}$}
    & \multicolumn{1}{c}{$\textsc{F}_{\mathcal{L}_2}$}
    & \multicolumn{1}{c}{$\textsc{A}_{\mathcal{L}_2}$}
    &
    & $\textsc{F}_{\textsc{CD} - \textsc{Rec}}$
    & \multicolumn{1}{c}{$\textsc{A}_{\textsc{CD} - \textsc{Rec}}$}
    & $\textsc{F}_{\textsc{CD} - \textsc{Real}}$
    & \multicolumn{1}{c}{$\textsc{A}_{\textsc{CD} - \textsc{Real}}$}
    & \multicolumn{1}{c}{$\textsc{F}_{\mathcal{L}_2}$}
    & \multicolumn{1}{c}{$\textsc{A}_{\mathcal{L}_2}$}
    &
    & $\textsc{F}_{\textsc{CD} - \textsc{Rec}}$
    & \multicolumn{1}{c}{$\textsc{A}_{\textsc{CD} - \textsc{Rec}}$}
    & $\textsc{F}_{\textsc{CD} - \textsc{Real}}$
    & \multicolumn{1}{c}{$\textsc{A}_{\textsc{CD} - \textsc{Real}}$}
    & \multicolumn{1}{c}{$\textsc{F}_{\mathcal{L}_2}$}
    & \multicolumn{1}{c}{$\textsc{A}_{\mathcal{L}_2}$} \\ \midrule
"""
# Six metric cells for one chain length. The leading "& &" closes the
# previous cell and inserts an empty spacer column.
table_entry = r"    & & %0.3f & \multicolumn{1}{c}{%0.3f} & \multicolumn{1}{c}{%0.3f} & \multicolumn{1}{c}{%0.3f} & \multicolumn{1}{c}{%0.3f} & \multicolumn{1}{c}{%0.3f}"
# Light grey rule drawn between groups of rows.
table_sep = r"\arrayrulecolor{black!30}\midrule\arrayrulecolor{black!100}"
table_footer = r"""
    \bottomrule
\end{tabular}%
"""
# "Decoupled" cell. Format with an icon macro name ("icoyes"/"icono"); the
# literal backslash before %s makes it a LaTeX command.
table_is_decoupled = r"& \multicolumn{1}{c}{\%s}"

In [3]:
from transforms import get_pc_ae_transform
from eval.chain_sampler import ChainSampler
from eval.metrics import l2_dist, chamfer_reconstructed, chamfer_real


def init_exps(model_name, model_path, ae_model):
    """
    Initialize the latent space mapper and the evaluation arguments.

    Args:
        model_name: Key looked up in ``mlp_mapper.__dict__`` to build the mapper;
            also passed as ``--model``.
        model_path: Checkpoint path passed as ``--resume``. When non-empty, its
            ``"model"`` entry is loaded into the mapper.
        ae_model: Autoencoder *name* (not a module), forwarded as ``--alt_ae_embeds``.

    Returns:
        tuple: ``(args, model, device)`` with the parsed arguments, the mapper
        moved to ``device``, and the ``torch.device`` built from ``args.device``.
    """
    # Pass the arguments as an explicit list. Formatting one string and
    # splitting it on whitespace would break checkpoint paths that contain
    # spaces (or are empty).
    argv = [
        "--ae-latent-dim", "256",
        "--text_model_name", "bert-base-uncased",
        "--dataset", "graphedits_chained",
        "--data_path", "/ibex/user/slimhy/ShapeWalk/",
        "--data_type", "release_chained",
        "--num_workers", "8",
        "--model", str(model_name),
        "--resume", str(model_path),
        "--resume_full_weights",
        "--device", "cuda",
        "--fetch_keys",
        "--use_embeds",
        "--alt_ae_embeds", str(ae_model),
        "--seed", "0",
    ]

    # Parse the arguments
    args = get_args_parser().parse_args(argv)
    # The CLIP flag is derived from the text model name, not taken from argv.
    args.use_clip = "clip" in args.text_model_name
    device = torch.device(args.device)

    model = mlp_mapper.__dict__[args.model](use_linear_proj=not args.use_clip)
    model.to(device)

    # Load the checkpoint. NOTE(review): torch.load unpickles the file, so
    # only point this at trusted checkpoints.
    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu")
        model.load_state_dict(checkpoint["model"])

    return args, model, device


def get_loader(args, batch_size, chain_length):
    """
    Build the validation DataLoader used by the chained evaluation.

    ``args.batch_size`` and ``args.chain_length`` are overwritten in place
    with the requested values before the dataset and sampler are created.
    Batches are ordered by a ``ChainSampler``; ``pin_memory`` is always
    enabled here (``args.pin_mem`` is not consulted) and incomplete final
    batches are dropped.
    """
    args.batch_size, args.chain_length = batch_size, chain_length

    val_set = build_shape_surface_occupancy_dataset("val", args=args)
    loader_kwargs = dict(
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True,
        sampler=ChainSampler(
            val_set, batch_size=args.batch_size, chain_length=args.chain_length
        ),
    )
    return torch.utils.data.DataLoader(val_set, **loader_kwargs)


def apply_edit(net, x_a, embed_ab):
    """
    Predict an edit offset with ``net`` and add it to the source latent.

    Both inputs are flattened from dimension 1 onward (the batch dimension is
    kept). ``net`` receives the flattened latent and text embedding and returns
    a residual; the result is ``residual + flattened_latent``.
    """
    latent_flat, text_flat = x_a.flatten(1), embed_ab.flatten(1)

    # The mapper predicts a residual in latent space rather than the target latent.
    residual = net(latent_flat, text_flat)
    return residual + latent_flat


def apply_iterated_edits(model, ae_model, embeds_a, embeds_b, embeds_text, device="cuda"):
    """
    Apply one edit step to a batch of source latents and decode the results.

    Args:
        model: Latent mapper passed to ``apply_edit``.
        ae_model: Autoencoder whose ``decoder`` turns latents into point clouds.
        embeds_a: Source latents for this step.
        embeds_b: Ground-truth target latents.
        embeds_text: Text embeddings describing the edit.
        device: Device all inputs are moved to. Defaults to ``"cuda"``, which
            matches the previously hard-coded ``.cuda()`` calls.

    Returns:
        tuple: ``((orig, rec, rec_gt), (x_a, x_b_edited, x_b))``. The point clouds
        are the decoder outputs reshaped to ``(batch, n_points, 3)``; ``n_points``
        is inferred from the decoder output instead of being fixed at 4096.
    """
    x_a = embeds_a.to(device)
    x_b = embeds_b.to(device)
    text = embeds_text.to(device)

    # Computed outside inference_mode on purpose: the edited latent is returned
    # to the caller (get_metrics feeds it back as the next step's source), and
    # inference tensors carry usage restrictions outside that mode.
    x_b_edited = apply_edit(model, x_a, text)

    b_size = x_b.shape[0]

    def _decode(latent):
        # Decode latents and view them as (batch, n_points, xyz).
        return ae_model.decoder(latent).reshape(b_size, -1, 3)

    with torch.inference_mode():
        orig = _decode(x_a)
        rec = _decode(x_b_edited)
        rec_gt = _decode(x_b)

    return (orig, rec, rec_gt), (x_a, x_b_edited, x_b)


def get_metrics(args, model, ae_model, data_loader, drop_n, pc_t):
    """
    Evaluate ``model`` on chains of edits and collect distance metrics.

    Each batch is treated as one step of a chain. At the start of a chain
    (``chain_count == 0``) the source latent is the loader's ``node_a``. After
    that it is the mapper's prediction from the previous batch, so errors
    accumulate along the chain. The counter resets after ``args.chain_length``
    steps.

    Args:
        args: Namespace; ``args.data_path`` and ``args.chain_length`` are read.
        model: Latent mapper forwarded to ``apply_iterated_edits``.
        ae_model: Autoencoder forwarded to ``apply_iterated_edits``.
        data_loader: Yields ``(chain_ids, edit_keys, node_a, node_b, text_embeds)``.
            NOTE(review): this presumably relies on the sampler emitting
            consecutive steps of the same chains in consecutive batches. TODO
            confirm; ``chain_ids`` is never checked here.
        drop_n: Number of additional trailing batches to skip (see the break below).
        pc_t: Point-cloud transform forwarded to ``chamfer_real``.

    Returns:
        MetricLogger: ``avg_*`` metrics are updated at every step; ``final_*``
        metrics are updated only at the last step of each completed chain.
    """
    metric_meter = MetricLogger()

    with torch.no_grad():
        chain_count = 0
        for batch_k, (chain_ids, edit_keys, node_a, node_b, text_embeds) in enumerate(
            data_loader
        ):
            # Stop at batch index len - 1 - drop_n: that batch and all later ones
            # are skipped, so even with drop_n == 0 the last batch is never
            # evaluated. NOTE(review): confirm this is intended (e.g. to drop a
            # trailing incomplete chain) and not an off-by-one.
            if batch_k == len(data_loader) - 1 - drop_n:
                break
            # First step of a chain: start from the ground-truth source latent.
            if chain_count == 0:
                prev_node = node_a

            # Apply the edits
            (p_a, p_b_pred, p_b), (x_a, x_b_pred, x_b) = apply_iterated_edits(
                model,
                ae_model,
                embeds_a=prev_node,
                embeds_b=node_b,
                embeds_text=text_embeds,
            )

            # Compute average pairwise L2 distance in feature space
            l2_distance = l2_dist(x_b, x_b_pred)

            # Compute average pairwise reconstructed CD
            cd_dist_reco = chamfer_reconstructed(p_b, p_b_pred)

            # Compute average pairwise real CD against the ground-truth shapes
            # identified by edit_keys. NOTE(review): the meaning of
            # n_samples=12 is defined by chamfer_real and is not visible here.
            cd_dist_real = chamfer_real(
                p_edited=p_b_pred,
                node_gt=edit_keys,
                transform=pc_t,
                data_path=args.data_path,
                n_samples=12,
            )

            # Log all metrics (each distance is expected to be a scalar tensor)
            metric_meter.update(
                avg_l2_dist=l2_distance.item(),
                avg_cd_dist_reco=cd_dist_reco.item(),
                avg_cd_dist_real=cd_dist_real.item(),
            )

            # The next step of this chain starts from the predicted latent.
            prev_node = x_b_pred

            # Log final chain metrics
            chain_count += 1
            if chain_count == args.chain_length:
                metric_meter.update(
                    final_l2_dist=l2_distance.item(),
                    final_cd_dist_reco=cd_dist_reco.item(),
                    final_cd_dist_real=cd_dist_real.item(),
                )
                chain_count = 0

    return metric_meter


def run_exps(
    args,
    model,
    ae_model,
    device,
    chain_configs=((10, 8, 0), (15, 4, 3), (20, 3, 0)),
):
    """
    Run the chained evaluation for several chain lengths and average the results.

    For each ``(chain_length, batch_size, drop_n)`` in ``chain_configs``, this
    builds a loader, evaluates ``model`` with ``get_metrics``, prints one
    ``table_entry`` row fragment and records the six summary metrics.
    Chamfer-distance metrics are scaled by 1e4 before reporting.

    Args:
        args: Parsed arguments; ``seed`` is read, and ``fetch_keys``,
            ``batch_size`` and ``chain_length`` are overwritten.
        model: Latent mapper; switched to eval mode.
        ae_model: Autoencoder used for decoding.
        device: Unused here; kept for interface compatibility.
        chain_configs: Iterable of ``(chain_length, batch_size, drop_n)``
            triples. The default reproduces the original hard-coded settings.

    Returns:
        dict: metric name -> mean over ``chain_configs``. A metric that was never
        logged for some configuration contributes NaN, so the problem is visible
        instead of hidden.
    """
    # Fix the seed for reproducibility
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    # NOTE(review): presumably makes the dataset also return the edit keys
    # consumed by chamfer_real. TODO confirm.
    args.fetch_keys = True

    # Get the point cloud transform
    pc_t = get_pc_ae_transform(args)

    # Switch the mapper to evaluation mode
    model = model.eval()

    # Column order of ``table_entry``. Spelled out explicitly instead of
    # relying on dict or MetricLogger iteration order.
    metric_names = (
        "final_cd_dist_reco",
        "avg_cd_dist_reco",
        "final_cd_dist_real",
        "avg_cd_dist_real",
        "final_l2_dist",
        "avg_l2_dist",
    )
    cd_scale = 10**4
    all_metrics = {name: [] for name in metric_names}

    for chain_length, batch_size, drop_n in chain_configs:
        data_loader_val = get_loader(
            args, batch_size=batch_size, chain_length=chain_length
        )
        metric_meter = get_metrics(args, model, ae_model, data_loader_val, drop_n, pc_t)

        row = []
        for name in metric_names:
            if name in metric_meter.meters:
                value = metric_meter.meters[name].global_avg
                if "cd" in name:
                    value *= cd_scale
            else:
                # Nothing was logged (e.g. no chain completed). Report NaN rather
                # than silently reusing the previous chain length's value.
                value = float("nan")
            all_metrics[name].append(value)
            row.append(value)

        # Print this chain length's cells of the LaTeX row
        print(table_entry % tuple(row))

    return {name: np.mean(values) for name, values in all_metrics.items()}

In [4]:
from models.pc_ae import load_pretrained_pc_ae


# Root directory of the latent-mapper checkpoints; each MODEL_MAP entry gives a
# path relative to it (joined by plain string concatenation, hence the
# trailing slash). NOTE(review): absolute cluster path; consider making it
# configurable.
CKPT_ROOT = "/ibex/user/slimhy/Shape2VecSet/output/graph_edit/dm/"
# One entry per table row, in print order. Each entry holds:
#   model_name   - constructor name in models.mlp_mapper and results key
#   checkpoint   - checkpoint path relative to CKPT_ROOT
#   method_code  - LaTeX label printed in the "Model" column
#   is_decoupled - selects the icon printed in the "decoupled?" column
MODEL_MAP = [
    {
        "model_name": "mlp_mapper_bert_l1__256",
        "checkpoint": "mlp_mapper_bert_l1__256/checkpoint-50.pth",
        "method_code": r"""\textsc{DirectGen}_{\textsc{Linear}}""",
        "is_decoupled": False,
    },
    {
        "model_name": "mlp_mapper_bert_bneck_1024_pcae_cpl",
        "checkpoint": "mlp_mapper_bert_bneck_1024_pcae__fine_chained_cpl/checkpoint-59.pth",
        "method_code": r"""\textsc{LateFusion}_{1024}""",
        "is_decoupled": False,
    },
    {
        "model_name": "mlp_mapper_bert_bneck_512_pcae_cpl",
        "checkpoint": "mlp_mapper_bert_bneck_512_pcae__fine_cpl__chained/checkpoint-59.pth",
        "method_code": r"""\textsc{LateFusion}_{512}""",
        "is_decoupled": False,
    },
    {
        "model_name": "mlp_mapper_bert_bneck_256_pcae_cpl",
        "checkpoint": "mlp_mapper_bert_bneck_256_pcae__fine_cpl__chained/checkpoint-59.pth",
        "method_code": r"""\textsc{LateFusion}_{256}""",
        "is_decoupled": False,
    },
    {
        "model_name": "mlp_mapper_bert_l8_pcae_cpl",
        "checkpoint": "mlp_mapper_bert_l8_pcae__fine_cpl__chained/checkpoint-59.pth",
        "method_code": r"""\textsc{Ours}_{512 \times 8}""",
        "is_decoupled": False,
    },
    {
        "model_name": "mlp_mapper_bert_l4_pcae_cpl",
        "checkpoint": "mlp_mapper_bert_l4_pcae__fine_cpl__chained/checkpoint-59.pth",
        "method_code": r"""\textsc{Ours}_{512 \times 4}""",
        "is_decoupled": False,
    },
    {
        "model_name": "mlp_mapper_bert_bneck_1024_pcae",
        "checkpoint": "mlp_mapper_bert_bneck_1024_pcae__fine_chained/checkpoint-59.pth",
        "method_code": r"""\textsc{LateFusion}_{1024}""",
        "is_decoupled": True,
    },
    {
        "model_name": "mlp_mapper_bert_bneck_512_pcae",
        "checkpoint": "mlp_mapper_bert_bneck_512_pcae__fine_chained/checkpoint-59.pth",
        "method_code": r"""\textsc{LateFusion}_{512}""",
        "is_decoupled": True,
    },
    {
        "model_name": "mlp_mapper_bert_bneck_256_pcae",
        "checkpoint": "mlp_mapper_bert_bneck_256_pcae__fine_chained/checkpoint-59.pth",
        "method_code": r"""\textsc{LateFusion}_{256}""",
        "is_decoupled": True,
    },
    {
        "model_name": "mlp_mapper_bert_l8_pcae",
        "checkpoint": "mlp_mapper_bert_l8_pcae__fine_chained/checkpoint-59.pth",
        "method_code": r"""\textsc{Ours}_{512 \times 8}""",
        "is_decoupled": True,
    },
    {
        "model_name": "mlp_mapper_bert_l4_pcae",
        "checkpoint": "mlp_mapper_bert_l4_pcae__fine_chained/checkpoint-59.pth",
        "method_code": r"""\textsc{Ours}_{512 \times 4}""",
        "is_decoupled": True,
    },
]
# Results are stored in a dict keyed by model_name, so a duplicate name would
# silently overwrite another model's results.
assert len({entry["model_name"] for entry in MODEL_MAP}) == len(MODEL_MAP), (
    "duplicate model_name in MODEL_MAP"
)

# Autoencoder name forwarded as --alt_ae_embeds, plus its checkpoint path.
AE_MODEL_NAME = "pc_ae"
# Backward-compatible alias for the original, misspelled constant name.
AE_MODEl_NAME = AE_MODEL_NAME
AE_MODEL_PATH = "/ibex/user/slimhy/Shape2VecSet/ckpt/pc_ae/best_model.pt"

# Indices into MODEL_MAP after which a light horizontal rule is printed.
line_seps = [0, 5]
# Shared autoencoder, loaded on the first iteration once the device is known.
ae_model = None
# model_name -> {metric name: value averaged over chain lengths}
all_results = {}

print(table_header)
for k, model_map in enumerate(MODEL_MAP):
    args, model, device = init_exps(
        model_map["model_name"],
        CKPT_ROOT + model_map["checkpoint"],
        AE_MODEl_NAME,
    )

    # Every mapper shares the same autoencoder, so only load it once.
    if ae_model is None:
        ae_model, _ = load_pretrained_pc_ae(AE_MODEL_PATH)
        ae_model = ae_model.to(device).eval()

    # Row label, then the "decoupled?" icon cell.
    print("$%s$" % model_map["method_code"])
    print(table_is_decoupled % ("icono" if not model_map["is_decoupled"] else "icoyes"))

    # run_exps prints the metric cells for every chain length.
    full_results = run_exps(args, model, ae_model, device)
    all_results[model_map["model_name"]] = full_results

    # End the LaTeX row, followed by a blank line.
    print("\\\\", end="\n\n")
    if k in line_seps:
        print(table_sep, end="\n\n")

print(table_footer)


\begin{tabular}{@{}lcccccccccccccccccccccccccc@{}}
    \toprule
    \multicolumn{1}{c}{\multirow{2}{*}{\textbf{Model}}} & \multicolumn{1}{c}{\multirow{2}{*}{decoupled?}}
    &  & \multicolumn{6}{c}{$|\mathcal{P}|$ = 10}
    &  & \multicolumn{6}{c}{$|\mathcal{P}|$ = 15}
    &  & \multicolumn{6}{c}{$|\mathcal{P}|$ = 20} \\ \cmidrule(l){4-23}
    \multicolumn{1}{c}{}  & \multicolumn{1}{c}{}
    &
    & $\textsc{F}_{\textsc{CD} - \textsc{Rec}}$
    & \multicolumn{1}{c}{$\textsc{A}_{\textsc{CD} - \textsc{Rec}}$}
    & $\textsc{F}_{\textsc{CD} - \textsc{Real}}$
    & \multicolumn{1}{c}{$\textsc{A}_{\textsc{CD} - \textsc{Real}}$}
    & \multicolumn{1}{c}{$\textsc{F}_{\mathcal{L}_2}$}
    & \multicolumn{1}{c}{$\textsc{A}_{\mathcal{L}_2}$}
    &
    & $\textsc{F}_{\textsc{CD} - \textsc{Rec}}$
    & \multicolumn{1}{c}{$\textsc{A}_{\textsc{CD} - \textsc{Rec}}$}
    & $\textsc{F}_{\textsc{CD} - \textsc{Real}}$
    & \multicolumn{1}{c}{$\textsc{A}_{\textsc{CD} - \textsc{Real}}$}
    & \multicolumn

In [11]:
# Summary table: one column group with the metrics averaged over every chain
# length. Note that these assignments replace the per-chain-length
# table_header / table_entry / table_footer defined in an earlier cell.
table_header = r"""
\begin{tabular}{@{}lllccccccc@{}}
    \toprule
    \multicolumn{1}{c}{\multirow{2}{*}{\textbf{Model}}} & \multicolumn{1}{c}{\multirow{2}{*}{Decoupled magnitude}} & \multicolumn{1}{c}{} & \multicolumn{6}{c}{Averaged $\ \forall |\mathcal{P}|$} \\  \cmidrule(l){4-9} 
    \multicolumn{1}{c}{} & \multicolumn{1}{c}{}
    & & $\textsc{F}_{\textsc{CD}}$
    & $\textsc{A}_{\textsc{CD}}$
    & $\textsc{F}_{\textsc{CD} - \textsc{Real}}$
    & $\textsc{A}_{\textsc{CD} - \textsc{Real}}$
    & $\textsc{F}_{\mathcal{L}_2}$ & $\textsc{A}_{\mathcal{L}_2}$ \\ \midrule
"""
# Leading "& &" closes the previous cell and adds an empty spacer column;
# six plain three-decimal metric cells follow.
table_entry = "    & & " + " & ".join(["%0.3f"] * 6)
table_footer = "\n".join((r"\bottomrule", r"\end{tabular}%"))

In [12]:
# Print the summary table: one row per model, metrics averaged over chain lengths.
print(table_header)
# Light rules after the baseline row (0) and after the last coupled model (5).
line_seps = [0, 5]
for k, model_map in enumerate(MODEL_MAP):
    results_entry = all_results[model_map["model_name"]]

    # Row label, then the "decoupled" icon cell.
    print("$%s$" % model_map["method_code"])
    print(table_is_decoupled % ("icono" if not model_map["is_decoupled"] else "icoyes"))

    # Metric cells in the column order of the header.
    print(
        table_entry
        % tuple(
            results_entry[metric]
            for metric in (
                "final_cd_dist_reco",
                "avg_cd_dist_reco",
                "final_cd_dist_real",
                "avg_cd_dist_real",
                "final_l2_dist",
                "avg_l2_dist",
            )
        )
    )

    # End the LaTeX row, followed by a blank line.
    print("\\\\", end="\n\n")
    if k in line_seps:
        print(table_sep, end="\n\n")
print(table_footer)


\begin{tabular}{@{}lllccccccc@{}}
    \toprule
    \multicolumn{1}{c}{\multirow{2}{*}{\textbf{Model}}} & \multicolumn{1}{c}{\multirow{2}{*}{Decoupled magnitude}} & \multicolumn{1}{c}{} & \multicolumn{6}{c}{Averaged $\ \forall |\mathcal{P}|$} \\  \cmidrule(l){4-9} 
    \multicolumn{1}{c}{} & \multicolumn{1}{c}{}
    & & $\textsc{F}_{\textsc{CD}}$
    & $\textsc{A}_{\textsc{CD}}$
    & $\textsc{F}_{\textsc{CD} - \textsc{Real}}$
    & $\textsc{A}_{\textsc{CD} - \textsc{Real}}$
    & $\textsc{F}_{\mathcal{L}_2}$ & $\textsc{A}_{\mathcal{L}_2}$ \\ \midrule

$\textsc{DirectGen}_{\textsc{Linear}}$
& \multicolumn{1}{c}{\icono}
    & & 3.140 & 2.497 & 4.054 & 3.900 & 2.150 & 1.819
\\

\arrayrulecolor{black!30}\midrule\arrayrulecolor{black!100}

$\textsc{LateFusion}_{1024}$
& \multicolumn{1}{c}{\icono}
    & & 1.831 & 1.318 & 2.764 & 2.559 & 1.444 & 1.251
\\

$\textsc{LateFusion}_{512}$
& \multicolumn{1}{c}{\icono}
    & & 1.673 & 1.296 & 2.658 & 2.548 & 1.462 & 1.290
\\

$\textsc{LateFusion}_{2