In [2]:
# Switch the working directory to the repository's src/ folder. The relative
# paths used later ("../data", "../checkpoints") are resolved from here, and
# presumably the project-local imports below rely on it too.
%cd ../src

/home/ubuntu/SPVD_Lightning/src


In [3]:
from models.ddpm_unet_cattn import SPVUnet
import torch
import lightning as L
from models.g_spvd import GSPVD

In [4]:
from torch.utils.data import DataLoader
from dataloaders.shapenet.shapenet_loader import ShapeNet

# Evaluation data: the ShapeNet "test" split restricted to `categories`,
# loaded with load_renders=True.
categories = ['airplane']
path = "../data/ShapeNet"

split = "test"
# NOTE(review): 2048 is presumably the number of points per cloud — confirm
# against the ShapeNet loader's signature.
test_dataset = ShapeNet(path, split, 2048, categories, load_renders=True)

loader_options = dict(batch_size=32, num_workers=4)
test_loader = DataLoader(test_dataset, **loader_options)

Loading (test) renders for airplane (02691156): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 808/808 [00:02<00:00, 373.61it/s]


In [5]:
from utils.hyperparams import load_hyperparams

# Hyperparameter file stored in the GSPVD distillation checkpoint directory
# for the selected categories (category names are joined with "-" to form the
# directory name).
hparams_path = f'../checkpoints/distillation/GSPVD/{"-".join(categories)}/hparams.yaml'

# `hparams` is indexed by string key in the cells that follow.
hparams = load_hyperparams(hparams_path)

In [6]:
# Constructor arguments for SPVUnet, copied from the saved hyperparameters
# under the same key names.
model_arg_names = (
    'voxel_size',
    'nfs',
    'attn_chans',
    'attn_start',
    'cross_attn_chans',
    'cross_attn_start',
    'cross_attn_cond_dim',
)
model_args = {name: hparams[name] for name in model_arg_names}

# Build the SPVUnet and wrap it in GSPVD; the wrapped module is the one the
# checkpoint is loaded into.
model = GSPVD(model=SPVUnet(**model_args))

In [7]:
# Sampling configuration.
steps = 1000        # passed to the scheduler as `steps`; also selects `{steps}-steps.ckpt`
scheduler = 'ddpm'  # 'ddim' selects the DDIM scheduler; any other value falls back to DDPM
ckpt_path = f'../checkpoints/distillation/GSPVD/{"-".join(categories)}/{steps}-steps.ckpt'

# map_location='cpu' keeps the load independent of the device the checkpoint
# was saved from and avoids holding a second copy of the weights on the GPU;
# load_state_dict copies the values into the model's own parameters.
# NOTE(review): weights_only=False unpickles arbitrary Python objects, so only
# load checkpoints from a trusted source. If this file is a plain state dict,
# weights_only=True should be sufficient — confirm and switch.
ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)

# Kept as the last expression so the key-matching report is displayed.
model.load_state_dict(ckpt)

<All keys matched successfully>

In [8]:
# Move the model to the GPU and put it in evaluation mode before sampling.
model = model.cuda().eval()

In [9]:
from my_schedulers.ddpm_scheduler import DDPMSparseScheduler
from my_schedulers.ddim_scheduler import DDIMSparseScheduler
from utils.visualization import display_pointclouds_grid

# Select the scheduler class by name: 'ddim' -> DDIM, any other value -> DDPM.
scheduler_cls = DDIMSparseScheduler if scheduler == 'ddim' else DDPMSparseScheduler

# Both schedulers are configured with the same arguments: the beta range,
# `n_steps` and `mode` from the saved hyperparameters, plus the step count
# chosen for sampling.
sched = scheduler_cls(
    beta_min=hparams['beta_min'],
    beta_max=hparams['beta_max'],
    steps=steps,
    init_steps=hparams['n_steps'],
    mode=hparams['mode'],
)

In [None]:
from tqdm.auto import tqdm

# Per-batch reference and generated point clouds, concatenated after the loop.
all_ref_pc = []
all_gen_pc = []

# Sampling is inference-only: disable autograd so no graph is recorded while
# the scheduler runs its denoising steps.
with torch.no_grad():
    # Outer bar tracks batches; `progress_bar=True` is passed through to
    # `sched.sample` for each batch.
    for datapoint in tqdm(test_loader, desc="Sampling test batches"):
        ref_pc = datapoint['pc']
        features = datapoint['render-features'].cuda()

        # Generate one sample per reference cloud, with the same number of
        # points, conditioned on the render features.
        B, N, C = ref_pc.shape
        gen_pc = sched.sample(model, B, N, reference=features, progress_bar=True)

        # Flatten each batch from (B, N, C) to (B * N, C) before collecting.
        all_ref_pc.append(ref_pc.reshape(B * N, C))
        all_gen_pc.append(gen_pc.reshape(B * N, C))

all_ref_pc = torch.cat(all_ref_pc)
all_gen_pc = torch.cat(all_gen_pc)

In [None]:
from tqdm import tqdm as tq
from tqdm.auto import tqdm  # notebook compatible
import time
# Scratch check of nested progress bars: the outer bar comes from tqdm.auto,
# the inner one from plain tqdm and is cleared after each pass (leave=False).
outer_bar = tqdm(range(5))
for i1 in outer_bar:
    inner_bar = tq(range(300), desc="test", leave=False)
    for i2 in inner_bar:
        # Stand-in for real work.
        time.sleep(0.01)