In [1]:
# Run from src/ so the project packages (models, dataloaders, utils, ...) are importable
# and the relative ../data and ../checkpoints paths resolve. Guarded so that re-running
# this cell does not climb one more directory each time, as a bare `%cd ../src` would.
import os

if os.path.basename(os.getcwd()) != 'src':
    os.chdir('../src')
print(os.getcwd())

/home/ubuntu/SPVD_Lightning/src


In [2]:
from models.ddpm_unet_cattn import SPVUnet
import torch
import lightning as L
from models.g_spvd import GSPVD

In [3]:
from torch.utils.data import DataLoader
from dataloaders.shapenet.shapenet_loader import ShapeNet

# Held-out test split restricted to the evaluated categories. Render features are
# loaded because sampling below is conditioned on datapoint['render-features'].
categories = ['chair']
path = "../data/ShapeNet"

test_dataset = ShapeNet(
    path,
    "test",
    2048,  # presumably the number of points per cloud — TODO confirm against ShapeNet
    categories,
    load_renders=True,
)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, num_workers=4)

Loading (test) renders for chair (03001627):   0%|          | 0/1317 [00:00<?, ?it/s]

In [4]:
from utils.hyperparams import load_hyperparams

# Hyper-parameters stored next to the distilled GSPVD checkpoints for this category set.
category_tag = "-".join(categories)
hparams_path = f'../checkpoints/distillation/GSPVD/{category_tag}/hparams.yaml'
hparams = load_hyperparams(hparams_path)

In [5]:
# SPVUnet constructor arguments are copied verbatim from the training hyper-parameters,
# so the architecture should match the checkpoint that is loaded into `model`.
unet_keys = (
    'voxel_size',
    'nfs',
    'attn_chans',
    'attn_start',
    'cross_attn_chans',
    'cross_attn_start',
    'cross_attn_cond_dim',
)
model_args = {key: hparams[key] for key in unet_keys}

# Wrap the denoising U-Net in GSPVD; the checkpoint's state dict targets this wrapper.
model = GSPVD(model=SPVUnet(**model_args))

In [6]:
steps = 1000
scheduler = 'ddim'
ckpt_path = f'../checkpoints/distillation/GSPVD/{"-".join(categories)}/{steps}-steps.ckpt'
# NOTE(review): weights_only=False unpickles arbitrary objects and can execute code;
# only load checkpoints you trust. The result is passed straight to load_state_dict,
# so it looks like a plain state dict and weights_only=True may work — TODO confirm.
# map_location='cpu' makes loading independent of the GPU the checkpoint was saved on;
# load_state_dict copies the tensors into the model's own (currently CPU) parameters.
ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
model.load_state_dict(ckpt)

<All keys matched successfully>

In [7]:
# Move the parameters to the current CUDA device and switch to evaluation mode;
# nn.Module.cuda() and .eval() both modify `model` in place.
model.cuda()
model.eval()

In [8]:
from my_schedulers.ddpm_scheduler import DDPMSparseScheduler
from my_schedulers.ddim_scheduler import DDIMSparseScheduler
from utils.visualization import display_pointclouds_grid

# Both sparse schedulers are constructed with the same keyword arguments here, so pick
# the class by name and build it once instead of duplicating the call per branch.
scheduler_classes = {
    'ddim': DDIMSparseScheduler,
    'ddpm': DDPMSparseScheduler,
}
scheduler_key = scheduler.lower()
if scheduler_key not in scheduler_classes:
    # Reject typos instead of silently falling back to DDPM.
    raise ValueError(
        f"Unknown scheduler {scheduler!r}; expected one of {sorted(scheduler_classes)}"
    )

sched = scheduler_classes[scheduler_key](
    beta_min=hparams['beta_min'],
    beta_max=hparams['beta_max'],
    steps=steps,
    init_steps=hparams['n_steps'],
    mode=hparams['mode'],
)

In [9]:
# from schedulers import create_sparse_scheduler
# sched = create_sparse_scheduler()

In [10]:
from itertools import islice

from tqdm.auto import tqdm

# Number of test batches to sample. Each batch runs a full `steps`-step sampling loop,
# so this is kept small; set to None to evaluate on the whole test split.
N_EVAL_BATCHES = 5
# Seed torch's global RNG so repeated runs generate the same clouds
# (assumes sched.sample draws its noise from torch's default generator — TODO confirm).
SEED = 0
torch.manual_seed(SEED)

all_ref_pc = []
all_gen_pc = []

# islice stops pulling batches once the budget is reached, instead of loading (and
# discarding) every remaining batch of the split.
n_eval_batches = (
    len(test_loader) if N_EVAL_BATCHES is None else min(N_EVAL_BATCHES, len(test_loader))
)
for datapoint in tqdm(islice(test_loader, N_EVAL_BATCHES), total=n_eval_batches):
    ref_pc = datapoint['pc'].cuda()
    features = datapoint['render-features'].cuda()

    # One generated cloud per reference cloud, with the same number of points,
    # conditioned on that cloud's render features.
    B, N, _ = ref_pc.shape
    gen_pc = sched.sample(model, B, N, reference=features)

    all_ref_pc.append(ref_pc)
    all_gen_pc.append(gen_pc)

# Both sets are left in the dataset's coordinate frame (presumably normalised with
# test_loader.dataset.mean / .std — TODO confirm); the metrics are computed in that space.
all_ref_pc = torch.cat(all_ref_pc).cuda()
# .cuda() guarantees the generated set is on the GPU whatever device sched.sample returns.
all_gen_pc = torch.cat(all_gen_pc).cuda()

  0%|          | 0/42 [00:00<?, ?it/s]

Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

In [11]:
# Sanity check: the generated clouds should roughly share the reference clouds'
# centre and extent. Prints ref/gen mean, then ref/gen max.
for stat in (torch.Tensor.mean, torch.Tensor.max):
    for cloud in (all_ref_pc, all_gen_pc):
        print(stat(cloud))

tensor(0.0033, device='cuda:0')
tensor(-0.0728, device='cuda:0')
tensor(4.2445, device='cuda:0')
tensor(4.2944, device='cuda:0')


In [None]:
# Persist the generated clouds (still CUDA tensors) in the parent directory of src/.
gen_pc_path = "../generated.pt"
torch.save(all_gen_pc, gen_pc_path)

In [12]:
from metrics.evaluation_metrics import compute_all_metrics, jsd_between_point_cloud_sets as JSD
from pprint import pprint

In [13]:
raw_metrics = compute_all_metrics(all_ref_pc, all_gen_pc, batch_size=32)


def _to_scalar(value):
    """Return `value` as-is if it is already a Python float, otherwise unwrap the tensor."""
    if isinstance(value, float):
        return value
    return value.cpu().detach().item()


# Plain Python numbers so the dict pretty-prints cleanly (no tensor/device reprs).
results = {name: _to_scalar(value) for name, value in raw_metrics.items()}
pprint(results)

  0%|          | 0/160 [00:00<?, ?it/s]

tensor([[0.0973, 0.7469, 1.0684,  ..., 0.2341, 0.6546, 0.4873],
        [0.1430, 0.6260, 1.1896,  ..., 0.1431, 0.4146, 0.3237],
        [0.6482, 1.8286, 0.0762,  ..., 1.1620, 0.7805, 0.3849],
        ...,
        [0.2150, 0.6572, 1.1578,  ..., 0.0739, 0.4854, 0.4084],
        [0.3580, 0.7539, 0.9558,  ..., 0.3391, 0.2168, 0.2658],
        [0.3879, 1.0635, 0.4969,  ..., 0.4535, 0.3070, 0.1849]])


  0%|          | 0/160 [00:00<?, ?it/s]

  0%|          | 0/160 [00:00<?, ?it/s]

{'1-NN-CD-acc': 0.375,
 '1-NN-CD-acc_f': 0.3187499940395355,
 '1-NN-CD-acc_t': 0.4312500059604645,
 '1-NN-EMD-acc': 0.42500001192092896,
 '1-NN-EMD-acc_f': 0.48124998807907104,
 '1-NN-EMD-acc_t': 0.3687500059604645,
 'lgan_cov-CD': 0.612500011920929,
 'lgan_cov-EMD': 0.637499988079071,
 'lgan_mmd-CD': 0.09422929584980011,
 'lgan_mmd-EMD': 0.6649810075759888,
 'lgan_mmd_smp-CD': 0.09692485630512238,
 'lgan_mmd_smp-EMD': 0.6361186504364014}


In [55]:
# Jensen-Shannon divergence between the generated and reference point-cloud sets
# (a divergence, so lower is better). Passed as (generated, reference); presumably this
# matches the helper's (sample, ref) argument order — TODO confirm.
jsd = JSD(all_gen_pc.cpu().numpy(), all_ref_pc.cpu().numpy())
# Plain print: pprint on an already-formatted string would show it wrapped in quotes.
print(f'JSD: {jsd}')

'JSD: 0.01925053807101662'
