In [1]:
# Switch the working directory to the repo's src/ folder. The relative dataset
# path used later ("../data/ShapeNet") resolves from here, and the project
# package imports (models, datasets, ...) presumably rely on it as well.
%cd ../src

/home/ubuntu/SPVD_Lightning/src


In [2]:
from models.ddpm_unet_cattn import SPVUnet
import torch
import lightning as L
from models.g_spvd import GSPVD



In [3]:
# SPVUnet backbone hyper-parameters. These presumably must match the
# configuration the checkpoint below was trained with, otherwise the strict
# load_state_dict call would report mismatched keys/shapes.
model_args = dict(
    voxel_size=0.1,
    nfs=(32, 64, 128, 256),
    attn_chans=8,
    attn_start=3,
    cross_attn_chans=8,
    cross_attn_start=2,
    cross_attn_cond_dim=768,
)

# Wrap the backbone in the GSPVD Lightning module. lr / training_steps look like
# training-only settings; placeholder values presumably suffice for inference.
model = GSPVD(model=SPVUnet(**model_args), lr=0.1, training_steps=1)

In [4]:
# Load the trained weights into the freshly built module.
# The path is relative to src/ (the working directory set by `%cd ../src` at the
# top), mirroring the relative dataset path, instead of a machine-specific
# absolute path.
# map_location="cpu" loads tensors onto the CPU regardless of the device layout
# the checkpoint was saved with; the model is moved to the GPU afterwards.
ckpt_path = '../checkpoints/GSPVD/all_categories_renders/checkpoints/epoch=399-step=111600.ckpt'
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
model.load_state_dict(ckpt['state_dict'])

<All keys matched successfully>

In [5]:
# Put the module on the default CUDA device, then switch it to evaluation mode
# (eval() disables training-time behaviour such as dropout).
model = model.cuda()
model = model.eval()

In [6]:
from my_schedulers.ddpm_scheduler import DDPMSparseScheduler
from utils.visualization import display_pointclouds_grid

# DDPM sampler with 1000 diffusion steps; beta_min / beta_max bound the noise schedule.
ddpm_sched = DDPMSparseScheduler(steps=1000, beta_min=1e-4, beta_max=2e-2)

In [7]:
from datasets.shapenet.shapenet_loader import ShapeNet

# Test split of ShapeNet for the chosen categories, loaded together with the
# render features used later as conditioning. The 2048 argument is presumably
# the number of points per cloud.
path = "../data/ShapeNet"
categories = ["skateboard"]
te = ShapeNet(path, "test", 2048, categories, load_renders=True)

Loading (test) renders for skateboard (04225987): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 31/31 [00:00<00:00, 36.89it/s]


In [8]:
import numpy as np

# Pick `samples` distinct test shapes whose renders will condition generation.
# Sampling is done without replacement, so no shape appears twice in the grid
# (the previous default, replace=True, almost always produced duplicates).
# A seeded generator makes Restart & Run All choose the same shapes every time.
samples = 16
REFERENCE_SEED = 0
assert len(te) >= samples, f"test split has only {len(te)} shapes, need {samples}"
reference_rng = np.random.default_rng(REFERENCE_SEED)
reference_idxs = reference_rng.choice(len(te), size=samples, replace=False)
references = [te[int(idx)] for idx in reference_idxs]

In [9]:
# Batch the per-shape render features into one tensor on the GPU; it is passed
# to the sampler as the `reference` conditioning input.
reference_images = torch.stack(
    tuple(ref["render-features"] for ref in references), dim=0
).to("cuda")

In [10]:
# Generate one point cloud per conditioning render.
# - The batch size is derived from `reference_images` instead of a hard-coded 16,
#   so it cannot drift from the number of references.
# - 2048 is presumably the number of points per generated cloud.
# - Gradients are disabled because this is inference only. Without that, autograd
#   could retain a graph across all diffusion steps (the wrapper is harmless if
#   `sample` already disables grad internally).
torch.manual_seed(0)  # make the stochastic sampling reproducible
n_generated = reference_images.shape[0]
with torch.no_grad():
    preds = ddpm_sched.sample(model, n_generated, 2048, reference=reference_images)

Sampling:   0%|▍                                                                                                                                                            | 3/1000 [00:00<02:21,  7.05it/s]

torch.Size([32768, 32])
torch.Size([32768, 32])
torch.Size([32768, 32])
torch.Size([32768, 32])


Sampling:   1%|█                                                                                                                                                            | 7/1000 [00:00<01:16, 12.96it/s]

torch.Size([32768, 32])
torch.Size([32768, 32])
torch.Size([32768, 32])
torch.Size([32768, 32])


Sampling:   1%|█▋                                                                                                                                                          | 11/1000 [00:00<01:00, 16.21it/s]

torch.Size([32768, 32])
torch.Size([32768, 32])
torch.Size([32768, 32])
torch.Size([32768, 32])


Sampling:   2%|██▎                                                                                                                                                         | 15/1000 [00:01<00:54, 17.93it/s]

torch.Size([32768, 32])
torch.Size([32768, 32])
torch.Size([32768, 32])
torch.Size([32768, 32])


Sampling:   2%|██▉                                                                                                                                                         | 19/1000 [00:01<00:52, 18.85it/s]

torch.Size([32768, 32])
torch.Size([32768, 32])
torch.Size([32768, 32])
torch.Size([32768, 32])


Sampling:   2%|███▌                                                                                                                                                        | 23/1000 [00:01<00:50, 19.28it/s]

torch.Size([32768, 32])
torch.Size([32768, 32])
torch.Size([32768, 32])
torch.Size([32768, 32])


Sampling:   3%|████▏                                                                                                                                                       | 27/1000 [00:01<00:49, 19.51it/s]

torch.Size([32768, 32])
torch.Size([32768, 32])
torch.Size([32768, 32])
torch.Size([32768, 32])


Sampling:   3%|████▊                                                                                                                                                       | 31/1000 [00:01<00:49, 19.61it/s]

torch.Size([32768, 32])
torch.Size([32768, 32])
torch.Size([32768, 32])
torch.Size([32768, 32])


Sampling:   4%|█████▍                                                                                                                                                      | 35/1000 [00:02<00:49, 19.68it/s]

torch.Size([32768, 32])
torch.Size([32768, 32])
torch.Size([32768, 32])
torch.Size([32768, 32])


Sampling:   4%|██████                                                                                                                                                      | 39/1000 [00:02<00:48, 19.72it/s]

torch.Size([32768, 32])
torch.Size([32768, 32])
torch.Size([32768, 32])
torch.Size([32768, 32])


Sampling:   4%|██████▋                                                                                                                                                     | 43/1000 [00:02<00:48, 19.70it/s]

torch.Size([32768, 32])
torch.Size([32768, 32])
torch.Size([32768, 32])
torch.Size([32768, 32])


Sampling:   5%|███████▎                                                                                                                                                    | 47/1000 [00:02<00:48, 19.70it/s]

torch.Size([32768, 32])
torch.Size([32768, 32])
torch.Size([32768, 32])
torch.Size([32768, 32])


Sampling:   5%|███████▉                                                                                                                                                    | 51/1000 [00:02<00:48, 19.65it/s]

torch.Size([32768, 32])
torch.Size([32768, 32])
torch.Size([32768, 32])
torch.Size([32768, 32])


Sampling:   6%|████████▌                                                                                                                                                   | 55/1000 [00:03<00:54, 17.41it/s]

torch.Size([32768, 32])
torch.Size([32768, 32])
torch.Size([32768, 32])





KeyboardInterrupt: 

In [None]:
# Render the generated point clouds as a grid.
display_pointclouds_grid(
    preds.cpu().numpy(),
    point_size=0.3,
    offset=8,
)

In [None]:
# Ground-truth point clouds of the same reference shapes, drawn with the same
# grid settings for side-by-side comparison with the generated samples.
real = torch.stack([ref["pc"] for ref in references], dim=0).numpy()
display_pointclouds_grid(real, point_size=0.3, offset=8)