In [1]:
"""
Neural listener evaluation of the models.
"""
import argparse
import datetime
import h5py
import json
import os
import time
from pathlib import Path

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from transformers import AutoTokenizer, CLIPTextModel, BertTokenizer, BertModel

import util.misc as misc
from engine_node2node import get_text_embeddings
from util.datasets import build_shape_surface_occupancy_dataset


def get_args_parser():
    parser = argparse.ArgumentParser("Performing Chained Eval", add_help=False)

    # Model parameters
    parser.add_argument(
        "--batch_size",
        default=32,
        type=int,
        help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus",
    )
    parser.add_argument(
        "--text_model_name",
        type=str,
        help="Text model name to use",
    )
    parser.add_argument(
        "--ae",
        type=str,
        metavar="MODEL",
        help="Name of autoencoder",
    )
    parser.add_argument(
        "--ae-latent-dim",
        type=int,
        default=512*8,
        help="AE latent dimension",
    )
    parser.add_argument(
        "--ae_pth",
        required=True,
        help="Autoencoder checkpoint"
    )
    parser.add_argument(
        "--point_cloud_size",
        default=2048,
        type=int,
        help="input size"
    )
    parser.add_argument(
        "--fetch_keys",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--use_clip",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--use_embeds",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--intensity_loss",
        action="store_true",
        default=False,
        help="Contrastive edit intensity loss using ground-truth labels.",
    )

    # Dataset parameters
    parser.add_argument(
        "--dataset",
        type=str,
        choices=["graphedits", "graphedits_chained", "graphedits_nrl"],
        help="dataset name",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        help="dataset path",
    )
    parser.add_argument(
        "--data_type",
        type=str,
        help="dataset type",
    )
    parser.add_argument(
        "--max_edge_level",
        default=None,
        type=int,
        help="maximum edge level to use",
    )
    parser.add_argument(
        "--chain_length",
        default=None,
        type=int,
        help="length of chains to load",
    )
    parser.add_argument(
        "--device", default="cuda", help="device to use for training / testing"
    )
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--num_workers", default=60, type=int)
    parser.add_argument(
        "--pin_mem",
        action="store_true",
        help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.",
    )
    parser.add_argument(
        "--alt_ae_embeds",
        type=str,
        default=None,
        help="Alternative autoencoder embeddings to use",
    )
    parser.add_argument(
        "--ft_bert",
        action="store_true",
        default=False,
        help="Also fine-tune the BERT model",
    )
    parser.add_argument(
        "--model",
        type=str,
        metavar="MODEL",
    )
    parser.add_argument(
        "--resume",
        default="",
        help="Resume from checkpoint"
    )
    parser.add_argument(
        "--resume_full_weights",
        action="store_true",
        default=False,
        help="Resume the full model weights with the EDM wrapper",
    )

    return parser

In [2]:
import models.listeners as listener

# Set dummy arg string to debug the parser
call_string = """--ae_pth /ibex/user/slimhy/Shape2VecSet/output/pc_ae/best_model.pt \
    --ae-latent-dim 256 \
    --num_workers 4 \
    --batch_size 64 \
    --text_model_name "bert-base-uncased" \
    --dataset graphedits_nrl \
    --data_path /ibex/user/slimhy/ShapeWalk_RND/ \
    --data_type release \
    --model nrl_listener_bert_256_pcae \
    --resume /ibex/user/slimhy/Shape2VecSet/output/graph_edit/dm/nrl_listener_bert_256_chained/checkpoint-799.pth \
    --resume_full_weights \
    --use_embeds \
    --alt_ae_embeds pc_ae \
    --seed 0"""

# Parse the arguments
args = get_args_parser()
args = args.parse_args(call_string.split())
args.use_clip = "clip" in args.text_model_name
device = torch.device(args.device)

model = listener.__dict__[args.model](use_linear_proj=not args.use_clip)
model.to(device)

# Load the checkpoint
if args.resume:
    print("Loading checkpoint [%s]..." % args.resume)
    checkpoint = torch.load(args.resume, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    print("Done.")

Loading checkpoint [/ibex/user/slimhy/Shape2VecSet/output/graph_edit/dm/nrl_listener_bert_256_chained/checkpoint-799.pth]...
Done.


In [3]:
# --------------------
# Fix the seed for reproducibility
seed = args.seed + misc.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)

cudnn.benchmark = True
args.fetch_keys = True
args.fetch_intensity = True

dataset_val = build_shape_surface_occupancy_dataset("val", args=args)

# Create data loaders
data_loader_val = torch.utils.data.DataLoader(
    dataset_val,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.num_workers,
    pin_memory=True,
    drop_last=True,
)
# --------------------

In [8]:
from util.misc import AverageMeter, MetricLogger
from losses.chamfer import chamfer_loss
from losses.listener_loss import ListenerLoss

model = model.eval()
metric_meter = MetricLogger()
criterion = ListenerLoss()

def get_affected_param(edge_dict):
    edge_dict = json.loads(edge_dict)
    edge = list(edge_dict.keys())
    if len(edge) > 2:
        return "err"
    else:
        return edge[0]

with torch.no_grad():
    for batch_k, (edit_keys, x_a, x_b, embed_ab, edge_dict, labels) in enumerate(data_loader_val):
        x_a = x_a.to(device)
        x_b = x_b.to(device)
        embed_ab = embed_ab.to(device)
        labels = labels.to(device)

        loss, pred = criterion(model, x_a, x_b, embed_ab, labels)
        acc = (pred > 0.5).eq(labels).float().mean()

        # Also log accuracy per parameter affected
        # use get_affected_param(edge_dict) which returns a single affected parameter
        unique_affected_params = [[get_affected_param(edge_dict[i]), i] for i in range(len(edge_dict))]
        # Map the affected parameters to the corresponding edge_dict index
        param_to_idx = {param: i for param, i in unique_affected_params}
        for param, param_idx in param_to_idx.items():
            param_acc = (pred[param_idx] > 0.5).eq(labels[param_idx]).float().mean()
            metric_meter.update(**{param: param_acc.item()})
        
        metric_meter.update(global_acc=acc.item())

print(metric_meter.global_acc.global_avg)

0.8845108695652174


In [9]:
print(str(metric_meter))

v_legs_shape_1: 1.0000 (0.9565)	v_back_leg_mid_y_offset_pct: 1.0000 (0.9556)	v_y: 1.0000 (0.9783)	v_x: 1.0000 (0.9783)	b_is_handles_cusion: 1.0000 (0.7500)	v_cr_count: 1.0000 (0.6667)	v_seat_shape: 1.0000 (0.9783)	v_z: 1.0000 (1.0000)	v_tr_scale_z: 1.0000 (0.9556)	v_seat_pos: 1.0000 (1.0000)	err: 1.0000 (0.7561)	v_curvature: 1.0000 (1.0000)	v_legs_shape_2: 1.0000 (0.5556)	v_back_leg_bottom_y_offset_pct: 1.0000 (0.9333)	v_legs_bevel: 1.0000 (0.5870)	v_monoleg_tent_count: 1.0000 (0.6571)	v_tr_shape_1: 1.0000 (0.9773)	global_acc: 0.8906 (0.8845)	v_tr_scale_y: 1.0000 (0.9512)	v_vr_count: 1.0000 (0.6667)


In [None]:
print("model_name:", args.model)
for chain_length, batch_size, drop_n in [[10, 8, 0], [15, 4, 3], [20, 3, 0]]:
    #print("chain_length=", chain_length, "batch_size=", batch_size)
    data_loader_val = get_loader(args, batch_size=batch_size, chain_length=chain_length)
    metric_meter = get_metrics(data_loader_val)

    #print("|P| = %d" % chain_length)
    table_str = "& & FCD & \multicolumn{1}{c}{ACD} & \multicolumn{1}{c}{FL2} & \multicolumn{1}{c}{AL2} "
    table_str = table_str.replace("FCD", "%0.3f" % round(metric_meter.final_cd_dist.global_avg*10**4, 3))
    table_str = table_str.replace("ACD", "%0.3f" % round(metric_meter.cd_dist.global_avg*10**4, 3))
    
    table_str = table_str.replace("FL2", "%0.3f" % round(metric_meter.final_l2_dist.global_avg, 3))
    table_str = table_str.replace("AL2", "%0.3f" % round(metric_meter.avg_l2_dist.global_avg, 3))

    print(table_str)
print("\\\\")