In [1]:
%cd ../src
%load_ext autoreload
%autoreload 2

/home/ubuntu/SPVD_Lightning/src


In [2]:
from models.ddpm_unet_cattn import SPVUnet
import torch
import lightning as L
from models.g_spvd import GSPVD

In [3]:
from torch.utils.data import DataLoader
from dataloaders.shapenet.shapenet_loader import ShapeNet

# Test split of ShapeNet for the selected categories. Renders are loaded so that
# their 'render-features' can be passed as the sampling reference later on.
path = "../data/ShapeNet"
categories = ["car"]
test_dataset = ShapeNet(
    path,
    "test",
    2048,  # presumably points per shape (same 2048 passed to sample()) -- confirm in ShapeNet
    categories,
    load_renders=True,
    total=100,  # NOTE(review): semantics defined by ShapeNet; not visible here
)
test_loader = DataLoader(test_dataset, num_workers=4, batch_size=32)

Loading (test) renders for car (02958343):   0%|          | 0/704 [00:00<?, ?it/s]

In [4]:
from utils.hyperparams import load_hyperparams

# Hyperparameters stored next to the distilled GSPVD checkpoints; the
# directory is named after the '-'-joined category list.
category_tag = "-".join(categories)
hparams_path = f"../checkpoints/distillation/GSPVD/{category_tag}/hparams.yaml"
hparams = load_hyperparams(hparams_path)

In [34]:
from my_schedulers.ddpm_scheduler import DDPMSparseScheduler
from my_schedulers.ddim_scheduler import DDIMSparseScheduler

# Number of DDIM sampling steps. The same value also selects which distilled
# checkpoint is loaded ("{steps}-steps.ckpt").
steps = 32

# Schedule settings (betas, training step count, mode) come from the saved
# hparams; only the sampling step count is chosen here.
ddim_sched = DDIMSparseScheduler(
    steps=steps,
    init_steps=hparams['n_steps'],
    beta_min=hparams['beta_min'],
    beta_max=hparams['beta_max'],
    mode=hparams['mode'],
)

In [35]:
# SPVUnet constructor arguments, copied verbatim from the saved hparams.
model_arg_keys = (
    'voxel_size',
    'nfs',
    'attn_chans',
    'attn_start',
    'cross_attn_chans',
    'cross_attn_start',
    'cross_attn_cond_dim',
)
model_args = {key: hparams[key] for key in model_arg_keys}

# U-Net backbone wrapped in the GSPVD module whose weights are loaded next.
model = GSPVD(model=SPVUnet(**model_args))

In [36]:
from utils.helper_functions import process_ckpt

ckpt_path = f'../checkpoints/distillation/GSPVD/{"-".join(categories)}/cond/{steps}-steps.ckpt'
# map_location="cpu": without it, tensors saved on a specific GPU (e.g. cuda:1)
# are restored to that same device, which may not exist on this machine.
# The whole model is moved to the GPU after loading instead.
# NOTE: weights_only=False unpickles arbitrary Python objects -- only load
# checkpoints from a trusted source.
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
# process_ckpt turns the raw checkpoint into something load_state_dict accepts
# (presumably extracting/renaming the state dict -- see the helper).
ckpt = process_ckpt(ckpt)
model.load_state_dict(ckpt)
model = model.eval().cuda()

In [37]:
# Draw `bs` samples, using the render features of the first `bs` test shapes
# as the `reference` passed to the sampler.
bs = 4
SEED = 0
# torch.manual_seed seeds the CPU and all CUDA generators, so (assuming the
# scheduler draws its noise from torch's RNG) repeated runs give the same samples.
torch.manual_seed(SEED)
ref = torch.stack([test_dataset[i]['render-features'] for i in range(bs)]).cuda()
# 2048 / 3: presumably points per cloud and coordinate dims -- confirm in sample().
pred, x0 = ddim_sched.sample(model, bs, 2048, 3, reference=ref, save=True)

Sampling:   0%|          | 0/32 [00:00<?, ?it/s]

In [38]:
from utils.visualization import display_pointclouds_grid

# Grid of the sampling trajectory: one row per snapshot, one column per sample.
# Assumes `x0` holds one saved prediction per sampling step (from save=True) -- TODO confirm.
n_frames = 10
n_saved = len(x0)
# Evenly spaced indices from the first to the last saved step, so the final
# prediction is always shown (the previous i * steps // 10 stopped at 28 of 32).
frame_idx = [round(i * (n_saved - 1) / (n_frames - 1)) for i in range(n_frames)]
# Reshape with `bs` rather than a hard-coded 4: with a different batch size a
# fixed 4 would silently merge or split individual point clouds.
frames = torch.stack([x0[i] for i in frame_idx]).reshape(n_frames * bs, -1, 3)
# Named `frames` (not `display`) so IPython's built-in display() is not shadowed.
# .cpu() is a no-op for CPU tensors and avoids a failure if x0 is on the GPU.
display_pointclouds_grid(frames.cpu().numpy(), offset=10, point_size=0.3, grid_dims=(n_frames, bs))

Output()