In [1]:
import argparse
import logging
import math
import os
import time

import torch
from tqdm.auto import tqdm

from utils.dataset import *
from utils.misc import *
from utils.data import *
from models.vae_gaussian import *
from models.vae_flow import *
from models.flow import add_spectral_norm, spectral_norm_power_iteration
from evaluation import *

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Model
# Run configuration. `args` is read by this and later cells but was never
# defined in the notebook, so a fresh kernel raised NameError here.
CKPT_PATH = ""  # TODO: path to a trained GaussianVAE checkpoint
args = argparse.Namespace(
    # Fall back to CPU so the notebook also runs on machines without CUDA.
    device='cuda' if torch.cuda.is_available() else 'cpu',
    batch_size=128,          # latent codes sampled per generation step (TODO: tune)
    sample_num_points=2048,  # points per generated cloud (TODO: tune)
    normalize=None,          # mode for normalize_point_clouds; None skips normalization
)

# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
# NOTE(review): ckpt['args'] is a non-tensor object, so torch versions whose
# torch.load defaults to weights_only=True may refuse to unpickle it; pass
# weights_only=False only for checkpoints from a trusted source.
ckpt = torch.load(CKPT_PATH, map_location=args.device)
model = GaussianVAE(ckpt['args']).to(args.device)
model.load_state_dict(ckpt['state_dict'])
model.eval()  # inference only: put layers into evaluation mode

In [None]:
# Datasets and loaders
dataset_path = ""  # TODO: path to the ShapeNet dataset
# DataLoader rejects batch_size=0 with a ValueError; reuse the generation batch
# size so the loader and the sampling loop agree.
batch_size = args.batch_size
test_dset = ShapeNetCore(
    path=dataset_path,
    # NOTE(review): "1111" looks like a placeholder; confirm the category
    # format ShapeNetCore expects for `cates`.
    cates="1111",
    split='test',
    scale_mode=None,
)
test_loader = DataLoader(test_dset, batch_size=batch_size, num_workers=0)

In [None]:

# Reference Point Clouds
# Stack every test-set cloud into one tensor with a leading shape dimension.
# Assumes all 'pointcloud' tensors share one shape — TODO confirm (stacking requires it).
ref_pcs = torch.stack([data['pointcloud'] for data in test_dset], dim=0)

# Generate Point Clouds
# Fixed seed so re-running this cell reproduces the same samples.
GEN_SEED = 2020
torch.manual_seed(GEN_SEED)

num_shapes = len(test_dset)
gen_batches = []
# One no_grad context for the whole loop: sampling never needs gradients.
with torch.no_grad():
    for _ in tqdm(range(math.ceil(num_shapes / args.batch_size)), 'Generate'):
        # Latents are drawn with the default (CPU) generator and then moved, so a
        # given seed yields the same codes regardless of args.device.
        z = torch.randn([args.batch_size, ckpt['args'].latent_dim]).to(args.device)
        x = model.sample(z, args.sample_num_points, flexibility=ckpt['args'].flexibility)
        gen_batches.append(x.cpu())
# Every batch is full-sized, so trim the surplus from the last one.
gen_pcs = torch.cat(gen_batches, dim=0)[:num_shapes]

# `logger` was used below but never defined in this notebook (NameError on a
# fresh kernel). Guard the handler so re-running the cell does not duplicate output.
logger = logging.getLogger('test_gen')
if not logger.handlers:
    logger.addHandler(logging.StreamHandler())
logger.setLevel(logging.INFO)

# NOTE(review): ref_pcs come from test_dset, built with scale_mode=None, while
# gen_pcs are normalized here; confirm both sets end up on the same scale
# before comparing them.
if args.normalize is not None:
    gen_pcs = normalize_point_clouds(gen_pcs, mode=args.normalize, logger=logger)