In [1]:
import json
import torch
import numpy as np
import os.path as osp
import pprint
import warnings

from argparse import ArgumentParser
from models.mlp import MLP
from models.point_net import PointNet
from models.pointcloud_autoencoder import PointcloudAutoencoder


def describe_pc_ae(args):
    """Build an (untrained) point-cloud autoencoder from a config namespace.

    :param args: namespace providing ``encoder_net``, ``encoder_conv_layers``,
        ``decoder_net``, ``decoder_fc_neurons`` and ``n_pc_points``.
    :return: a ``PointcloudAutoencoder`` wrapping the built encoder and decoder.
    :raises NotImplementedError: if ``encoder_net`` is not ``"pointnet"`` or
        ``decoder_net`` is not ``"mlp"``.
    """
    if args.encoder_net == "pointnet":
        # init_feat_dim=3: one input feature per point coordinate.
        ae_encoder = PointNet(init_feat_dim=3, conv_dims=args.encoder_conv_layers)
        # The latent size fed to the decoder is the width of the last conv layer.
        encoder_latent_dim = args.encoder_conv_layers[-1]
    else:
        raise NotImplementedError(f"Unsupported encoder_net: {args.encoder_net!r}")

    if args.decoder_net == "mlp":
        ae_decoder = MLP(
            in_feat_dims=encoder_latent_dim,
            # Final layer emits n_pc_points * 3 values (a flattened point cloud).
            # List concatenation builds a new list, so args is not mutated.
            out_channels=args.decoder_fc_neurons + [args.n_pc_points * 3],
            b_norm=False,
        )
    else:
        # Without this branch ae_decoder would be unbound below (UnboundLocalError).
        raise NotImplementedError(f"Unsupported decoder_net: {args.decoder_net!r}")

    return PointcloudAutoencoder(ae_encoder, ae_decoder)


def load_state_dicts(checkpoint_file, map_location=None, **kwargs):
    """Load torch items from a saved checkpoint into the given objects.

    :param checkpoint_file: path (or file-like object) accepted by ``torch.load``.
    :param map_location: optional ``map_location`` forwarded to ``torch.load``
        (e.g. ``"cpu"`` to remap tensors saved on a GPU).
    :param kwargs: ``name=obj`` pairs; ``obj.load_state_dict(checkpoint[name])``
        is called for each one.
    :return: the checkpoint's ``'epoch'`` entry, or ``None`` if it has none.
    :raises KeyError: if a requested name is not present in the checkpoint.
    """
    # map_location=None is torch.load's own default, so it can be forwarded as-is.
    checkpoint = torch.load(checkpoint_file, map_location=map_location)

    for key, value in kwargs.items():
        if key not in checkpoint:
            raise KeyError(
                f"{key!r} not found in checkpoint; available keys: {list(checkpoint)}"
            )
        value.load_state_dict(checkpoint[key])

    # Return the stored epoch unchanged; a truthiness check would turn epoch 0 into None.
    return checkpoint.get("epoch")


def read_saved_args(config_file, override_or_add_args=None, verbose=False):
    """Rebuild an argparse namespace from arguments saved as JSON.

    :param config_file: path to a JSON file whose top-level object maps
        argument names to values.
    :param override_or_add_args: optional dict; each entry overrides (or adds)
        the attribute of the same name, e.g. ``{'gpu': '0'}`` yields
        ``args.gpu == '0'``.
    :param verbose: if true, pretty-print the resulting arguments to stdout.
    :return: the populated namespace.
    """
    # Parsing an empty argv yields a fresh, attribute-less Namespace to fill in.
    namespace = ArgumentParser().parse_args([])
    with open(config_file, "r") as handle:
        saved = json.load(handle)
    namespace.__dict__ = saved

    if override_or_add_args is not None:
        for attr_name, attr_value in override_or_add_args.items():
            setattr(namespace, attr_name, attr_value)

    if verbose:
        print(pprint.pformat(vars(namespace)))

    return namespace


def load_pretrained_pc_ae(model_file):
    """Build a point-cloud AE from its saved config and load pretrained weights.

    The config is read from ``config.json.txt`` in the same directory as
    ``model_file``. A warning is emitted if that config's ``log_dir`` does not
    point at ``model_file``'s location.

    :param model_file: path to the checkpoint to load (e.g. ``.../best_model.pt``).
    :return: ``(pc_ae, pc_ae_args)`` -- the loaded model and its config namespace.
    """
    config_file = osp.join(osp.dirname(model_file), "config.json.txt")
    pc_ae_args = read_saved_args(config_file)
    pc_ae = describe_pc_ae(pc_ae_args)

    # Normalize both sides so relative paths, "..", or redundant separators in
    # the saved log_dir do not trigger a spurious mismatch warning.
    expected_file = osp.abspath(osp.join(pc_ae_args.log_dir, "best_model.pt"))
    if expected_file != osp.abspath(model_file):
        warnings.warn(
            "The saved best_model.pt in the corresponding log_dir is not equal to the one requested.",
            stacklevel=2,
        )

    best_epoch = load_state_dicts(model_file, model=pc_ae)
    print(f"Pretrained PC-AE is loaded at epoch {best_epoch}.")
    return pc_ae, pc_ae_args

Jitting Chamfer 3D
Loaded JIT 3D CUDA chamfer distance


In [2]:
# Pretrained point-cloud AE checkpoint; load_pretrained_pc_ae reads the
# matching config.json.txt from the same directory.
pretrained_shape_generator = '/home/slimhy/Documents/changeit3d/data/pretrained/pc_autoencoders/pointnet/rs_2022/points_4096/all_classes/scaled_to_align_rendering/08-07-2022-22-23-42/best_model.pt'
gpu_id = 0

device = torch.device(f"cuda:{gpu_id}")

# Build the AE from its saved config, load its weights, then move it to the
# GPU and put it in eval mode for inference.
pc_ae, pc_ae_args = load_pretrained_pc_ae(pretrained_shape_generator)
pc_ae = pc_ae.to(device).eval()

Pretrained PC-AE is loaded at epoch 186.


In [3]:
import os
import glob

# Directory of point clouds stored as one .npy array per shape.
pc_dir = "/home/slimhy/Documents/datasets/ShapeWalk_RND/release/pc/"
N_SHAPES = 16    # point clouds per batch
N_POINTS = 4096  # points per cloud fed to the AE
SEED = 0         # fixes which shapes and points are drawn

# glob order is filesystem-dependent; sort so the seeded draw is reproducible.
pc_files = sorted(glob.glob(os.path.join(pc_dir, "*.npy")))
if len(pc_files) < N_SHAPES:
    raise ValueError(
        f"Need at least {N_SHAPES} .npy files in {pc_dir}, found {len(pc_files)}."
    )

rng = np.random.default_rng(SEED)
# Draw N_SHAPES *distinct* files (np.random.choice samples with replacement by default).
pc_files = rng.choice(pc_files, size=N_SHAPES, replace=False)

pc_list = []
for pc_file in pc_files:
    points = np.load(pc_file)  # assumes an (N, 3) array per file -- verified below
    if points.shape[0] < N_POINTS:
        raise ValueError(f"{pc_file} has {points.shape[0]} points, need {N_POINTS}.")
    # Random subset of N_POINTS points per cloud; subsampling before stacking
    # also lets files have different N.
    point_idx = rng.choice(points.shape[0], size=N_POINTS, replace=False)
    pc_list.append(torch.from_numpy(points[point_idx]))

# Stack into an N_SHAPES x N_POINTS x 3 float tensor on the model's device.
pc_batch = torch.stack(pc_list, dim=0).float().to(device)
# The AE encoder was built with init_feat_dim=3, so each point must have 3 features.
assert pc_batch.shape == (N_SHAPES, N_POINTS, 3), pc_batch.shape

In [4]:
# Encode/decode the sampled batch with the AE. A one-element list holding a
# single {"pointcloud": ...} batch is passed as the `loader` argument; the
# second return value is not used here.
rec, _, avg_cd = pc_ae.reconstruct(
    loader=[{"pointcloud": pc_batch}],
    device=device,
    loss_rule='chamfer',
)