In [1]:
# Run from src/ so the project packages (models, dataloaders, utils, my_schedulers)
# import and the ../data and ../checkpoints paths used below resolve relative to it.
# Re-running is harmless: ../src resolved from inside src/ is src/ itself.
%cd ../src

/home/ubuntu/SPVD_Lightning/src


In [2]:
from models.ddpm_unet_cattn import SPVUnet
import torch
import lightning as L
from models.g_spvd import GSPVD

In [13]:
## Hyperparameters
# Sampler family: get_sched builds DDIMSparseScheduler when this is 'ddim',
# DDPMSparseScheduler otherwise; get_ckpt also uses it in checkpoint paths.
scheduler: str = "ddim"

# True -> load per-step-count checkpoints from the distillation folder and build
# the scheduler with init_steps=hparams['n_steps'] (see get_sched / get_ckpt).
distilled: bool = False

# True -> pass the reference render features to the sampler as `reference`.
conditional: bool = True

# ShapeNet categories; joined with "-" to locate hparams.yaml and checkpoints.
categories: list = ["chair"]

In [14]:
from torch.utils.data import DataLoader
from dataloaders.shapenet.shapenet_loader import ShapeNet

path = "../data/ShapeNet"  # dataset root, relative to src/

# Test split with renders loaded, so each item carries "render-features" used
# below as the conditioning reference.
test_dataset = ShapeNet(
    path,
    "test",
    2048,  # presumably points per cloud; matches the 2048 used when sampling — TODO confirm
    categories,
    load_renders=True,
    total=5,  # NOTE(review): meaning not visible here (possibly a cap on loaded items) — confirm in ShapeNet
)

Loading (test) renders for chair (03001627):   0%|          | 0/1317 [00:00<?, ?it/s]

In [15]:
from utils.hyperparams import load_hyperparams

# Architecture and diffusion hyperparameters stored next to the distillation
# checkpoints for this category set; read by the model and scheduler cells below.
hparams_path = '../checkpoints/distillation/GSPVD/{}/hparams.yaml'.format("-".join(categories))
hparams = load_hyperparams(hparams_path)

In [16]:
# Build the SPVUnet backbone from the architecture fields in hparams.yaml, then
# wrap it in GSPVD (the module whose state dict the checkpoints are loaded into).
model_args = {
    key: hparams[key]
    for key in (
        'voxel_size',
        'nfs',
        'attn_chans',
        'attn_start',
        'cross_attn_chans',
        'cross_attn_start',
        'cross_attn_cond_dim',
    )
}

model = GSPVD(model=SPVUnet(**model_args))

In [17]:
# Move parameters to the GPU and switch to eval mode for sampling (both in place).
model.cuda()
model.eval()

In [25]:
from my_schedulers.ddpm_scheduler import DDPMSparseScheduler
from my_schedulers.ddim_scheduler import DDIMSparseScheduler
from utils.helper_functions import process_ckpt

def get_sched(steps, step_size=1):
    """Build the sparse diffusion sampler selected by the global config flags.

    Reads the notebook globals `scheduler`, `distilled` and `hparams`.

    Args:
        steps: value passed to the scheduler as `steps`.
        step_size: DDIM stride. Only forwarded for non-distilled DDIM; it is
            silently ignored in every other configuration.

    Returns:
        A DDIMSparseScheduler when `scheduler == 'ddim'`, otherwise a
        DDPMSparseScheduler. `init_steps` is `hparams['n_steps']` except for
        non-distilled DDPM, where it equals `steps`.
    """
    use_ddim = scheduler == 'ddim'
    sched_cls = DDIMSparseScheduler if use_ddim else DDPMSparseScheduler

    options = dict(
        beta_min=hparams['beta_min'],
        beta_max=hparams['beta_max'],
        steps=steps,
        # Non-distilled DDPM runs its full chain; everything else keeps the
        # training-time step count from hparams.
        init_steps=hparams['n_steps'] if (use_ddim or distilled) else steps,
        mode=hparams['mode'],
    )
    if use_ddim and not distilled:
        options['step_size'] = step_size

    return sched_cls(**options)

def get_ckpt(steps, step_size=1):
    """Load and post-process the GSPVD checkpoint matching the global config.

    Reads the notebook globals `distilled`, `scheduler` and `categories`.

    Path selection:
      * distilled            -> distillation/GSPVD/<cats>/<scheduler>/<steps>-steps.ckpt
      * ddim, step_size > 1  -> distillation/GSPVD/<cats>/1000-steps.ckpt (ignores `steps`)
      * otherwise            -> ShapeNet/GSPVD/<cats>/<scheduler>/<steps>-steps.ckpt

    Returns:
        Whatever `process_ckpt` returns for the loaded object; the notebook
        passes it straight to `model.load_state_dict`.
    """
    cat_key = "-".join(categories)
    distill_root = f'../checkpoints/distillation/GSPVD/{cat_key}'

    if distilled:
        ckpt_path = f'{distill_root}/{scheduler}/{steps}-steps.ckpt'
    elif scheduler == 'ddim' and step_size > 1:
        # Strided DDIM always uses the single 1000-step model in the distillation folder.
        ckpt_path = f'{distill_root}/1000-steps.ckpt'
    else:
        ckpt_path = f'../checkpoints/ShapeNet/GSPVD/{cat_key}/{scheduler}/{steps}-steps.ckpt'

    # NOTE(review): weights_only=False unpickles arbitrary objects (can execute
    # code); only point this at trusted, locally produced checkpoints.
    return process_ckpt(torch.load(ckpt_path, weights_only=False))
    

In [26]:
ref = test_dataset[0]

# Conditioning input: the sample's render features with a leading batch dim of 1,
# on the GPU. Left as None for the unconditional model.
ref_img = None
if conditional:
    ref_img = ref["render-features"].unsqueeze(0).cuda()

# Ground-truth point cloud, shown first in the comparison grid.
real_pc = ref["pc"].numpy()

In [27]:
# Point clouds for the grid: ground truth first, then one sample per DDIM stride.
# NOTE: this name shadows IPython's display(); kept because the plotting cell reads it.
display = [real_pc]
start_noise = None

# Every stride below reuses one network: the 1000-step checkpoint in the
# distillation folder. get_ckpt selects that path whenever scheduler == 'ddim'
# and step_size > 1, so step_size=2 here only picks the path.
ckpt = get_ckpt(1000, step_size=2)
model.load_state_dict(ckpt)

# Seed so the shared starting noise (and any stochastic sampling) is reproducible.
torch.manual_seed(0)

# Strides 1, 2, 4, ..., 64 over the 1000-step schedule; the sampler runs
# 1000, 500, 250, 125, 63, 32, 16 steps respectively.
with torch.no_grad():
    for step_size in (2 ** k for k in range(7)):
        sched = get_sched(1000, step_size)

        if start_noise is None:
            # Drawn once so every stride denoises the identical initial noise.
            start_noise = sched.create_noise((1, 2048, 3), device='cuda')

        gen_pc = sched.sample(model, 1, 2048, 3, reference=ref_img, starting_noise=start_noise)
        display.append(gen_pc[0].detach().cpu().numpy())

Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/250 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/63 [00:00<?, ?it/s]

Sampling:   0%|          | 0/32 [00:00<?, ?it/s]

Sampling:   0%|          | 0/16 [00:00<?, ?it/s]

In [28]:
from utils.visualization import display_pointclouds_grid
display_pointclouds_grid(
    display,
    offset=8,
    point_size=0.3,
    grid_dims=(2, 4),  # 2 x 4 = ground truth + one sample per stride (7)
)

Output()