In [1]:
# Change to the code root two directories up so that the project packages
# imported in the next cell (util, models, losses) can be resolved.
%cd ../../

/ibex/user/slimhy/Shape2VecSet/code


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [2]:
"""
Evaluation of the models.
"""
import argparse
import os.path as osp
import pprint
import json
import warnings

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from torch.utils.data import Sampler

import util.misc as misc
from util.misc import MetricLogger
from util.datasets import build_shape_surface_occupancy_dataset

import models.mlp_mapper as mlp_mapper
from models.mlp import MLP
from models.point_net import PointNet
from models.pointcloud_autoencoder import PointcloudAutoencoder

from losses.chamfer import chamfer_loss


def get_args_parser():
    """Build the argument parser for the chained-edit evaluation.

    The parser is created with ``add_help=False``, so it can be used as a parent
    parser or parsed from a literal string inside the notebook.

    Returns:
        argparse.ArgumentParser: parser with model, dataset and runtime options.
    """
    parser = argparse.ArgumentParser("Performing Chained Eval", add_help=False)

    # Model parameters
    parser.add_argument(
        "--batch_size",
        default=32,
        type=int,
        help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)",
    )
    parser.add_argument(
        "--text_model_name",
        type=str,
        help="Text model name to use",
    )
    parser.add_argument(
        "--ae",
        type=str,
        metavar="MODEL",
        help="Name of autoencoder",
    )
    # Hyphenated flag; argparse exposes it as `args.ae_latent_dim`.
    parser.add_argument(
        "--ae-latent-dim",
        type=int,
        default=512*8,
        help="AE latent dimension",
    )
    parser.add_argument(
        "--ae_pth",
        required=True,
        help="Autoencoder checkpoint"
    )
    parser.add_argument(
        "--point_cloud_size",
        default=2048,
        type=int,
        help="input size"
    )
    parser.add_argument(
        "--fetch_keys",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--use_clip",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--use_embeds",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--intensity_loss",
        action="store_true",
        default=False,
        help="Contrastive edit intensity loss using ground-truth labels.",
    )

    # Dataset parameters
    parser.add_argument(
        "--dataset",
        type=str,
        choices=["graphedits", "graphedits_chained"],
        help="dataset name",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        help="dataset path",
    )
    parser.add_argument(
        "--data_type",
        type=str,
        help="dataset type",
    )
    parser.add_argument(
        "--max_edge_level",
        default=None,
        type=int,
        help="maximum edge level to use",
    )
    parser.add_argument(
        "--chain_length",
        default=None,
        type=int,
        help="length of chains to load",
    )
    parser.add_argument(
        "--device", default="cuda", help="device to use for training / testing"
    )
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--num_workers", default=60, type=int)
    parser.add_argument(
        "--pin_mem",
        action="store_true",
        help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.",
    )
    parser.add_argument(
        "--alt_ae_embeds",
        type=str,
        default=None,
        help="Alternative autoencoder embeddings to use",
    )
    parser.add_argument(
        "--ft_bert",
        action="store_true",
        default=False,
        help="Also fine-tune the BERT model",
    )
    parser.add_argument(
        "--model",
        type=str,
        metavar="MODEL",
    )
    parser.add_argument(
        "--resume",
        default="",
        help="Resume from checkpoint"
    )
    parser.add_argument(
        "--resume_full_weights",
        action="store_true",
        default=False,
        help="Resume the full model weights with the EDM wrapper",
    )

    return parser

Jitting Chamfer 3D
Loaded JIT 3D CUDA chamfer distance


In [3]:
"""
xx mlp_mapper_bert_l1__256
/ibex/user/slimhy/Shape2VecSet/output/graph_edit/dm/mlp_mapper_bert_l1__256/checkpoint-50.pth

xx mlp_mapper_bert_direct_latent_256
/ibex/user/slimhy/Shape2VecSet/output/graph_edit/dm/mlp_mapper_bert_direct_latent_256/checkpoint-50.pth

====================================

xx mlp_mapper_bert_bneck_1024_pcae_cpl
/ibex/user/slimhy/Shape2VecSet/output/graph_edit/dm/mlp_mapper_bert_bneck_1024_pcae__fine_chained_cpl/checkpoint-59.pth

xx mlp_mapper_bert_bneck_512_pcae_cpl
/ibex/user/slimhy/Shape2VecSet/output/graph_edit/dm/mlp_mapper_bert_bneck_512_pcae__fine_cpl__chained/checkpoint-59.pth

xx mlp_mapper_bert_bneck_256_pcae_cpl
/ibex/user/slimhy/Shape2VecSet/output/graph_edit/dm/mlp_mapper_bert_bneck_256_pcae__fine_cpl__chained/checkpoint-59.pth

xx mlp_mapper_bert_l8_pcae_cpl
/ibex/user/slimhy/Shape2VecSet/output/graph_edit/dm/mlp_mapper_bert_l8_pcae__fine_cpl__chained/checkpoint-59.pth

xx mlp_mapper_bert_l4_pcae_cpl
/ibex/user/slimhy/Shape2VecSet/output/graph_edit/dm/mlp_mapper_bert_l4_pcae__fine_cpl__chained/checkpoint-59.pth

====================================
xx mlp_mapper_bert_bneck_1024_pcae
/ibex/user/slimhy/Shape2VecSet/output/graph_edit/dm/mlp_mapper_bert_bneck_1024_pcae__fine_chained/checkpoint-59.pth

xx mlp_mapper_bert_bneck_512_pcae
/ibex/user/slimhy/Shape2VecSet/output/graph_edit/dm/mlp_mapper_bert_bneck_512_pcae__fine_chained/checkpoint-59.pth

xx mlp_mapper_bert_bneck_256_pcae
/ibex/user/slimhy/Shape2VecSet/output/graph_edit/dm/mlp_mapper_bert_bneck_256_pcae__fine_chained/checkpoint-59.pth

xx mlp_mapper_bert_l8_pcae
/ibex/user/slimhy/Shape2VecSet/output/graph_edit/dm/mlp_mapper_bert_l8_pcae__fine_chained/checkpoint-59.pth

xx mlp_mapper_bert_l4_pcae
/ibex/user/slimhy/Shape2VecSet/output/graph_edit/dm/mlp_mapper_bert_l4_pcae__fine_chained/checkpoint-59.pth
"""

# Set dummy arg string to debug the parser
call_string = """--ae_pth /ibex/user/slimhy/Shape2VecSet/output/pc_ae/best_model.pt \
    --ae-latent-dim 256 \
    --text_model_name bert-base-uncased \
    --dataset graphedits \
    --data_path /ibex/user/slimhy/ShapeWalk_RND/ \
    --data_type release \
    --num_workers 8 \
    --model mlp_mapper_bert_bneck_1024_pcae \
    --resume /ibex/user/slimhy/Shape2VecSet/output/graph_edit/dm/mlp_mapper_bert_bneck_1024_pcae__fine_chained/checkpoint-59.pth \
    --resume_full_weights \
    --device cuda \
    --fetch_keys \
    --use_embeds \
    --alt_ae_embeds pc_ae \
    --seed 0"""

# Parse the arguments
args = get_args_parser()
args = args.parse_args(call_string.split())
args.use_clip = "clip" in args.text_model_name
device = torch.device(args.device)

model = mlp_mapper.__dict__[args.model](use_linear_proj=not args.use_clip)
model.to(device)

# Load the checkpoint
if args.resume:
    print("Loading checkpoint [%s]..." % args.resume)
    checkpoint = torch.load(args.resume, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    print("Done.")

Loading checkpoint [/ibex/user/slimhy/Shape2VecSet/output/graph_edit/dm/mlp_mapper_bert_bneck_1024_pcae__fine_chained/checkpoint-59.pth]...
Done.


In [4]:
# --------------------
# Fix the seed for reproducibility. The seed is offset by misc.get_rank(),
# presumably the distributed process rank.
seed = args.seed + misc.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)

# cuDNN autotuning: faster, but it may select non-deterministic kernels.
cudnn.benchmark = True
# Flags read by the dataset builder (presumably to also return edit keys and
# intensity labels — TODO confirm in util.datasets).
args.fetch_keys = True
args.fetch_intensity = True

# NOTE(review): this evaluation notebook reads the "train" split — confirm intended.
dataset_train = build_shape_surface_occupancy_dataset("train", args=args)

# Create data loaders.
# drop_last=False: evaluation must score every sample. Dropping the final
# partial batch would silently exclude up to batch_size - 1 samples from the
# reported metrics.
# NOTE(review): pin_memory is forced on and ignores args.pin_mem.
data_loader = torch.utils.data.DataLoader(
    dataset_train,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.num_workers,
    pin_memory=True,
    drop_last=False,
)
# --------------------
# --------------------



In [5]:
def describe_pc_ae(args):
    """Build an (untrained) point-cloud autoencoder from its saved configuration.

    Args:
        args: namespace providing ``encoder_net``, ``encoder_conv_layers``,
            ``decoder_net``, ``decoder_fc_neurons`` and ``n_pc_points``.

    Returns:
        PointcloudAutoencoder: a PointNet encoder paired with an MLP decoder.

    Raises:
        NotImplementedError: if ``encoder_net`` is not ``"pointnet"`` or
            ``decoder_net`` is not ``"mlp"``.
    """
    # Make an AE.
    if args.encoder_net != "pointnet":
        raise NotImplementedError(f"Unsupported encoder_net: {args.encoder_net!r}")
    # PointNet over raw xyz points (3 input features). The width of its last
    # conv layer is the latent size handed to the decoder.
    ae_encoder = PointNet(init_feat_dim=3, conv_dims=args.encoder_conv_layers)
    encoder_latent_dim = args.encoder_conv_layers[-1]

    if args.decoder_net != "mlp":
        # Previously this fell through and crashed with a NameError on `ae_decoder`.
        raise NotImplementedError(f"Unsupported decoder_net: {args.decoder_net!r}")
    # The final layer emits n_pc_points * 3 values: a flat xyz vector per shape.
    ae_decoder = MLP(
        in_feat_dims=encoder_latent_dim,
        out_channels=args.decoder_fc_neurons + [args.n_pc_points * 3],
        b_norm=False,
    )

    return PointcloudAutoencoder(ae_encoder, ae_decoder)


def load_state_dicts(checkpoint_file, map_location=None, **kwargs):
    """Load torch items from saved state dictionaries.

    Args:
        checkpoint_file: path to a dict saved with ``torch.save``.
        map_location: forwarded to ``torch.load``. ``None`` keeps torch's
            default placement.
        **kwargs: ``name=obj`` pairs. ``obj.load_state_dict(checkpoint[name])``
            is called for each pair.

    Returns:
        The checkpoint's ``'epoch'`` entry, or ``None`` when it has none.
        An epoch of 0 is returned as 0; the previous truthiness check
        turned it into ``None``.

    Raises:
        KeyError: if a requested name is missing from the checkpoint.
    """
    # map_location=None is torch.load's own default, so a single call covers both cases.
    checkpoint = torch.load(checkpoint_file, map_location=map_location)

    for key, value in kwargs.items():
        value.load_state_dict(checkpoint[key])

    return checkpoint.get('epoch')


def read_saved_args(config_file, override_or_add_args=None, verbose=False):
    """
    Load a namespace of arguments previously dumped to a JSON file.

    :param config_file: json file containing arguments (a single JSON object).
    :param override_or_add_args: dict, e.g., {'gpu': '0'} will set the resulting
        args.gpu to '0'. Keys absent from the file are added.
    :param verbose: if True, pretty-print the resulting arguments.
    :return: argparse.Namespace with one attribute per JSON key.
    """
    # JSON object keys are always strings, so they map directly to attributes.
    with open(config_file, "r", encoding="utf-8") as f_in:
        args = argparse.Namespace(**json.load(f_in))

    if override_or_add_args is not None:
        for key, val in override_or_add_args.items():
            setattr(args, key, val)

    if verbose:
        print(pprint.pformat(vars(args)))

    return args


def load_pretrained_pc_ae(model_file):
    """Rebuild the point-cloud AE described next to ``model_file`` and load its weights.

    The configuration is read from ``config.json.txt`` in the same directory
    as ``model_file``.

    Args:
        model_file: path to the saved autoencoder checkpoint.

    Returns:
        tuple: ``(pc_ae, pc_ae_args)``, i.e. the model with weights loaded and
        the saved configuration namespace.
    """
    config_file = osp.join(osp.dirname(model_file), "config.json.txt")
    pc_ae_args = read_saved_args(config_file)
    pc_ae = describe_pc_ae(pc_ae_args)

    # Normalize both sides. Otherwise a relative log_dir or redundant path
    # separators would trigger a spurious mismatch warning.
    expected_model_file = osp.abspath(osp.join(pc_ae_args.log_dir, "best_model.pt"))
    if expected_model_file != osp.abspath(model_file):
        warnings.warn(
            "The saved best_model.pt in the corresponding log_dir is not equal to the one requested."
        )

    best_epoch = load_state_dicts(model_file, model=pc_ae)
    print(f"Pretrained PC-AE is loaded at epoch {best_epoch}.")
    return pc_ae, pc_ae_args


# Instantiate the pretrained point-cloud autoencoder, move it to the evaluation
# device and switch it to eval mode (it is used only for decoding latents).
print("Loading autoencoder [%s]..." % args.ae_pth)
pc_ae, pc_ae_args = load_pretrained_pc_ae(args.ae_pth)
pc_ae = pc_ae.to(device).eval()
print("Done.")

Loading autoencoder [/ibex/user/slimhy/Shape2VecSet/output/pc_ae/best_model.pt]...
Pretrained PC-AE is loaded at epoch 186.
Done.




In [22]:
def apply_edit(net, x_a, embed_ab):
    """Predict an edit offset for ``x_a`` from ``embed_ab`` and apply it residually.

    Both inputs are flattened from dim 1 onwards, e.g. (B, D, K) -> (B, D*K).

    Returns:
        The flattened ``x_a`` plus ``net(x_a, embed_ab)``.
    """
    flat_latent, flat_embed = x_a.flatten(1), embed_ab.flatten(1)
    # The network predicts an offset; the edited latent is source + offset.
    return net(flat_latent, flat_embed) + flat_latent

def apply_edit__gt(net, x_a, x_b, embed_ab):
    """Apply an edit using the ground-truth direction and the predicted magnitude.

    ``net.forward_decoupled(x_a, embed_ab)`` yields ``(edit_dir, magnitude)``.
    The predicted direction is discarded and replaced by the unit vector
    pointing from ``x_a`` to ``x_b``. The result therefore reflects only the
    predicted magnitude.

    Args:
        net: model exposing ``forward_decoupled``.
        x_a: source latents; flattened from dim 1.
        x_b: target latents; flattened from dim 1.
        embed_ab: edit embeddings; flattened from dim 1.

    Returns:
        ``x_a + unit(x_b - x_a) * magnitude``, with shape (B, M).
    """
    x_a = x_a.flatten(1)
    # Flatten the target too, so that x_b - x_a is a per-sample (B, M)
    # difference rather than a mis-shaped broadcast for (B, D, K) latents.
    x_b = x_b.flatten(1)
    embed_ab = embed_ab.flatten(1)

    _edit_dir, magnitude = net.forward_decoupled(x_a, embed_ab)
    gt_vec = x_b - x_a
    # Per-row unit vector. The epsilon guards zero-length (no-op) edits.
    opt_direction = gt_vec / (torch.norm(gt_vec, p=2, dim=1, keepdim=True) + 1e-8)

    return opt_direction * magnitude + x_a

def apply_edit__gt_mag(net, x_a, x_b, embed_ab):
    """Apply an edit using the predicted direction and the ground-truth magnitude.

    ``net.forward_decoupled(x_a, embed_ab)`` yields ``(edit_dir, magnitude)``.
    The predicted direction is kept, and the predicted magnitude is replaced by
    the per-sample length ||x_b - x_a||.

    The previous version always raised NameError (``opt_direction`` was
    undefined) and wrote a direction vector into the magnitude slot.

    NOTE(review): assumes ``edit_dir`` is unit-norm, as implied by the
    direction * magnitude composition used elsewhere — confirm against the
    mapper's ``forward_decoupled``.

    Returns:
        ``x_a + edit_dir * ||x_b - x_a||``, with shape (B, M).
    """
    x_a = x_a.flatten(1)
    x_b = x_b.flatten(1)
    embed_ab = embed_ab.flatten(1)

    edit_dir, _magnitude = net.forward_decoupled(x_a, embed_ab)
    # Oracle magnitude: the length of the true latent displacement, shape (B, 1).
    gt_magnitude = torch.norm(x_b - x_a, p=2, dim=1, keepdim=True)

    return edit_dir * gt_magnitude + x_a

def apply_iterated_edits(model, embeds_a, embeds_b, embeds_text, use_gt=False, device="cuda"):
    """Apply one text-conditioned edit per sample and decode every latent with ``pc_ae``.

    Despite the name, a single edit step is applied per call.

    Args:
        model: edit mapper. It is passed to ``apply_edit``, or to
            ``apply_edit__gt`` when ``use_gt`` is True (ground-truth
            direction, predicted magnitude).
        embeds_a: source shape latents for the batch.
        embeds_b: target shape latents for the batch.
        embeds_text: text embeddings describing the edit.
        use_gt: choose ``apply_edit__gt`` over ``apply_edit``.
        device: where the inputs are moved before editing. The default
            ``"cuda"`` matches the previous hard-coded ``.cuda()`` calls.

    Returns:
        ``(orig, rec, rec_gt), (x_a, x_b_edited, x_b)``: point clouds of shape
        (B, N, 3) decoded by the module-level ``pc_ae`` from the source,
        edited and target latents, followed by those latents.
    """
    embeds_a = embeds_a.to(device)
    embeds_b = embeds_b.to(device)
    embeds_text = embeds_text.to(device)

    if use_gt:
        x_b_edited = apply_edit__gt(model, embeds_a, embeds_b, embeds_text)
    else:
        x_b_edited = apply_edit(model, embeds_a, embeds_text)
    x_a, x_b = embeds_a, embeds_b

    b_size = x_b.shape[0]

    with torch.inference_mode():
        # The decoder output is reshaped to (B, N, 3). N is inferred rather
        # than hard-coded to 4096, so AEs with another n_pc_points also work.
        orig = pc_ae.decoder(x_a).reshape(b_size, -1, 3)
        rec = pc_ae.decoder(x_b_edited).reshape(b_size, -1, 3)
        rec_gt = pc_ae.decoder(x_b).reshape(b_size, -1, 3)

    return (orig, rec, rec_gt), (x_a, x_b_edited, x_b)


In [None]:
from tqdm import tqdm

# Fresh metric logger, then put the edit mapper in eval mode for inference.
metric_meter = MetricLogger()
model = model.eval()

def get_affected_param(edge_dict):
    """Return the name of the parameter changed by a graph edit.

    Args:
        edge_dict: JSON-encoded object whose keys are the edited parameter names.

    Returns:
        str: the first key of the decoded object, or ``"err"`` when the object
        has no keys (previously an IndexError) or more than two keys.
    """
    edited_params = list(json.loads(edge_dict).keys())
    # NOTE(review): an edit touching exactly two parameters is attributed to
    # the first one; `> 2` may have been meant as `> 1` — confirm.
    if not edited_params or len(edited_params) > 2:
        return "err"
    return edited_params[0]

def get_metrics(data_loader, use_gt=True, metric_meter=None):
    """Run the edit mapper over ``data_loader`` and accumulate evaluation metrics.

    For each batch the edit is applied with ``apply_iterated_edits``. The
    following meters are updated:

    * ``avg_l2_dist``: batch mean of the per-sample ||x_b_edited - x_b||.
    * ``avg_l2_scale``: batch mean of the per-sample ||x_b_edited - x_a||.
    * ``cd_dist`` / ``cd_scale``: ``chamfer_loss`` of the edited reconstruction
      against the ground-truth / source reconstruction.
    * ``<param>_cd_dist`` / ``<param>_l2_dist``: the same distances for each
      sample, bucketed by ``get_affected_param``. Every sample contributes one
      update; previously only the last sample per parameter in each batch did.

    Args:
        data_loader: yields ``(node_ids, node_a, node_b, text_embeds, edge_dict)``,
            where ``edge_dict`` holds one JSON string per sample.
        use_gt: forwarded to ``apply_iterated_edits``. The default True
            preserves the previously hard-coded ground-truth-direction mode.
        metric_meter: logger to update. A new ``MetricLogger`` is created when
            None.

    Returns:
        The updated metric logger.
    """
    if metric_meter is None:
        metric_meter = MetricLogger()

    with torch.no_grad():
        for _node_ids, node_a, node_b, text_embeds, edge_dict in tqdm(data_loader, total=len(data_loader)):
            # Apply the edits
            (orig, rec, rec_gt), (x_a, x_b_edited, x_b) = apply_iterated_edits(
                model,
                embeds_a=node_a,
                embeds_b=node_b,
                embeds_text=text_embeds,
                use_gt=use_gt,
            )
            # Work on flat (B, M) latents so that norms reduce over features only.
            x_a = x_a.flatten(1)
            x_b = x_b.flatten(1)
            x_b_edited = x_b_edited.flatten(1)

            # Per-sample L2 distances in latent space, shape (B,). Without
            # `dim`, torch.norm returns a single norm over the whole batch.
            l2_per_sample = torch.norm(x_b_edited - x_b, p=2, dim=1)
            l2_scale_per_sample = torch.norm(x_b_edited - x_a, p=2, dim=1)

            # Batch-level Chamfer distances.
            cd_dist = chamfer_loss(rec, rec_gt, reduction="mean").mean()
            cd_scale = chamfer_loss(rec, orig, reduction="mean").mean()

            metric_meter.update(avg_l2_dist=l2_per_sample.mean().item(),
                                avg_l2_scale=l2_scale_per_sample.mean().item(),
                                cd_dist=cd_dist.item(),
                                cd_scale=cd_scale.item())

            # Per-parameter metrics: one update per sample, keyed by the
            # parameter its edit affects.
            for sample_idx, sample_edge in enumerate(edge_dict):
                param = get_affected_param(sample_edge)
                param_cd_dist = chamfer_loss(rec[sample_idx].unsqueeze(0),
                                             rec_gt[sample_idx].unsqueeze(0),
                                             reduction="mean").mean()
                metric_meter.update(**{param + "_cd_dist": param_cd_dist.item(),
                                       param + "_l2_dist": l2_per_sample[sample_idx].item()})

    return metric_meter

# Run the full evaluation pass over data_loader. The returned MetricLogger
# holds one meter per metric name.
metric_meter = get_metrics(data_loader)

 30%|██▉       | 8543/28500 [02:22<05:27, 60.89it/s]

In [None]:
"""
defaultdict(util.misc.SmoothedValue,
            {'avg_l2_dist': <util.misc.SmoothedValue at 0x152c06037610>,
             'avg_l2_scale': <util.misc.SmoothedValue at 0x152c06037640>,
             'cd_dist': <util.misc.SmoothedValue at 0x152c06037700>,
             'cd_scale': <util.misc.SmoothedValue at 0x152c06037760>,
             'v_legs_bevel_cd_dist': <util.misc.SmoothedValue at 0x152c060377c0>,
             'v_legs_bevel_l2_dist': <util.misc.SmoothedValue at 0x152c060377f0>,
             'v_z_cd_dist': <util.misc.SmoothedValue at 0x152c06037850>,
             'v_z_l2_dist': <util.misc.SmoothedValue at 0x152c060378b0>,
             'v_seat_pos_cd_dist': <util.misc.SmoothedValue at 0x152c06037910>,
             'v_seat_pos_l2_dist': <util.misc.SmoothedValue at 0x152c06037970>,
             'v_tr_scale_z_cd_dist': <util.misc.SmoothedValue at 0x152c060379d0>,
             'v_tr_scale_z_l2_dist': <util.misc.SmoothedValue at 0x152c06037a30>,
             'v_legs_shape_1_cd_dist': <util.misc.SmoothedValue at 0x152c06037a90>,
             'v_legs_shape_1_l2_dist': <util.misc.SmoothedValue at 0x152c06037af0>,
             'v_back_leg_bottom_y_offset_pct_cd_dist': <util.misc.SmoothedValue at 0x152c06037b50>,
             'v_back_leg_bottom_y_offset_pct_l2_dist': <util.misc.SmoothedValue at 0x152c06037bb0>,
             'v_y_cd_dist': <util.misc.SmoothedValue at 0x152c06037c10>,
             'v_y_l2_dist': <util.misc.SmoothedValue at 0x152c06037c70>,
             'v_seat_shape_cd_dist': <util.misc.SmoothedValue at 0x152c06037cd0>,
             'v_seat_shape_l2_dist': <util.misc.SmoothedValue at 0x152c06037d30>,
             'v_tr_shape_1_cd_dist': <util.misc.SmoothedValue at 0x152c06037d90>,
             'v_tr_shape_1_l2_dist': <util.misc.SmoothedValue at 0x152c06037df0>,
             'v_curvature_cd_dist': <util.misc.SmoothedValue at 0x152c06037e50>,
             'v_curvature_l2_dist': <util.misc.SmoothedValue at 0x152c06037eb0>,
             'v_back_leg_mid_y_offset_pct_cd_dist': <util.misc.SmoothedValue at 0x152c06037f10>,
             'v_back_leg_mid_y_offset_pct_l2_dist': <util.misc.SmoothedValue at 0x152c06037f70>,
             'b_is_handles_cusion_cd_dist': <util.misc.SmoothedValue at 0x152c06037fd0>,
             'b_is_handles_cusion_l2_dist': <util.misc.SmoothedValue at 0x152c075fc070>,
             'v_x_cd_dist': <util.misc.SmoothedValue at 0x152c075fc0d0>,
             'v_x_l2_dist': <util.misc.SmoothedValue at 0x152c075fc130>,
             'err_cd_dist': <util.misc.SmoothedValue at 0x152c075fc190>,
             'err_l2_dist': <util.misc.SmoothedValue at 0x152c075fc2b0>,
             'v_tr_scale_y_cd_dist': <util.misc.SmoothedValue at 0x152c075fc400>,
             'v_tr_scale_y_l2_dist': <util.misc.SmoothedValue at 0x152c075fc430>,
             'v_legs_shape_2_cd_dist': <util.misc.SmoothedValue at 0x152c075fc460>,
             'v_legs_shape_2_l2_dist': <util.misc.SmoothedValue at 0x152c075fc4c0>,
             'v_monoleg_tent_count_cd_dist': <util.misc.SmoothedValue at 0x152c075fc760>,
             'v_monoleg_tent_count_l2_dist': <util.misc.SmoothedValue at 0x152c075fc8b0>,
             'v_cr_count_cd_dist': <util.misc.SmoothedValue at 0x152c075fc8e0>,
             'v_cr_count_l2_dist': <util.misc.SmoothedValue at 0x152c075fca90>,
             'v_vr_count_cd_dist': <util.misc.SmoothedValue at 0x152c075fd360>,
             'v_vr_count_l2_dist': <util.misc.SmoothedValue at 0x152c075fd4b0>})
"""
param_groups = {
    "seat_height": ["v_seat_pos"],
    "backrest curvature": ["v_curvature"],
    "object width/height/depth": ["v_x", "v_x", "v_y"],
    "seat roudness": ["v_seat_shape"],
    "top bar thickness/height": ["v_tr_shape_1", "v_tr_shape_1"],
    "legs thickness": ["v_legs_shape_1", "v_legs_shape_2"],
    "adding/removing handle cushions": ["b_is_handles_cusion"],
    "number of legs/backrest rails": ["v_monoleg_tent_count", "v_cr_count", "v_vr_count"],
    "legs bending/curvature/roundness/indentation": ["v_legs_bevel", "v_legs_bevel"],
}

# Group the metrics by parameter group in metric_logger.meters
for param_group, param_list in param_groups.items():
    group_avg_l2_dist = 0
    group_avg_cd_dist = 0
    for param in sorted(param_list):
        group_avg_cd_dist += metric_meter.meters[param + "_cd_dist"].avg*(10**4)
        group_avg_l2_dist += metric_meter.meters[param + "_l2_dist"].avg
    group_avg_l2_dist /= len(param_list)
    group_avg_cd_dist /= len(param_list)
    print(f"{param_group} & {group_avg_l2_dist:.4f} & {group_avg_cd_dist:.4f} \\\\")