In [None]:
# Switch the working directory to ../src; presumably this is what lets the
# project-local packages imported below (models, dataloaders, utils,
# my_schedulers) resolve — TODO confirm.
# NOTE(review): the path is relative, so re-running this cell without restarting
# the kernel applies `../src` again starting from the new working directory.
%cd ../src
# Load IPython's autoreload extension; mode 2 reloads all modules before each
# execution, so edits to the project sources are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
from models.ddpm_unet_cattn import SPVUnet
import torch
import lightning as L
from models.g_spvd import GSPVD

In [None]:
## Hyperparameters
# ShapeNet categories to load; also joined with "-" to form checkpoint paths.
categories = ['airplane']

# Step counts to sample with, one generated cloud per entry.
# Other values previously tried: 500, 250, 125, 63, 32, 16, 8, 4, 2
steps_to_run = [1000]

# 'ddim' selects the DDIM scheduler in get_sched; any other value selects DDPM.
scheduler = 'ddim'

# When True, the reference render features are passed to the sampler.
conditional = True

In [None]:
from torch.utils.data import DataLoader
from dataloaders.shapenet.shapenet_loader import ShapeNet

# Root folder of the ShapeNet data, relative to the src working directory.
path = "../data/ShapeNet"

# Test split restricted to `categories`, with render data loaded alongside the
# point clouds.
test_dataset = ShapeNet(
    path,
    "test",
    2048,  # presumably points per cloud (matches the 2048 used when sampling) — TODO confirm
    categories,
    load_renders=True,
    total=200,  # presumably caps the number of samples — TODO confirm
)

In [None]:
from PIL import Image
import matplotlib.pyplot as plt

# Preview the selected render of 50 consecutive test samples in a 10x5 grid.
fig, ax = plt.subplots(10, 5, figsize=(20, 20))

# First dataset index shown; raise it to page through later samples.
offset = 0
# ax.flat walks the grid row by row, i.e. panel i is ax[i // 5, i % 5].
for i, panel in enumerate(ax.flat):
    r = test_dataset[i + offset]
    img = Image.open("../data/ShapeNet/renders/" + r['filename'] + f"/00{r['selected-view']}.png")

    panel.imshow(img)
    panel.axis('off')
plt.show()

In [None]:
from PIL import Image
import matplotlib.pyplot as plt

# Show eight entries of one test sample side by side in a 2x4 grid.
target = 48
fig, ax = plt.subplots(2, 4, figsize=(20, 10))

# ax.flat walks the grid row by row, i.e. panel i is ax[i // 4, i % 4].
for i, panel in enumerate(ax.flat):
    # The dataset is indexed with a (sample, i) pair; the second element
    # presumably picks the rendered view — TODO confirm against ShapeNet.__getitem__.
    r = test_dataset[target, i]
    img = Image.open("../data/ShapeNet/renders/" + r['filename'] + f"/00{r['selected-view']}.png")

    panel.imshow(img)
    panel.axis('off')
plt.show()

In [None]:
# Reference entry used both as conditioning input and as the ground-truth cloud.
ref = test_dataset[48, 6]

# Conditioning input: the render features with a leading batch axis, on the GPU.
# The features are only read when conditioning is enabled.
if conditional:
    ref_img = ref["render-features"].unsqueeze(0).cuda()
else:
    ref_img = None

# Ground-truth point cloud as a NumPy array, for display next to the samples.
real_pc = ref["pc"].numpy()

In [None]:
from PIL import Image
# Open the render that belongs to the chosen reference entry.
ref_render_dir = "../data/ShapeNet/renders/" + ref['filename']
img = Image.open(ref_render_dir + f"/00{ref['selected-view']}.png")
# Bare last expression so the notebook shows the image as the cell output.
img

In [None]:
from utils.hyperparams import load_hyperparams

# Hyperparameter file stored with the checkpoints; the category names are joined
# with "-" to form the directory name.
# NOTE(review): this always reads from the `distillation` checkpoint tree,
# whichever scheduler/checkpoint is selected later — confirm that is intended.
hparams_path = f'../checkpoints/distillation/GSPVD/{"-".join(categories)}/hparams.yaml'

# Loaded settings; later cells index this by key (e.g. 'voxel_size', 'beta_min',
# 'n_steps', 'mode').
hparams = load_hyperparams(hparams_path)

In [None]:
# Names of the SPVUnet constructor arguments, each taken unchanged from the
# loaded hyperparameters.
model_arg_names = (
    'voxel_size',
    'nfs',
    'attn_chans',
    'attn_start',
    'cross_attn_chans',
    'cross_attn_start',
    'cross_attn_cond_dim',
)
model_args = {name: hparams[name] for name in model_arg_names}

# Build the U-Net and hand it to the GSPVD wrapper; `model` is the wrapper.
model = GSPVD(model=SPVUnet(**model_args))

In [None]:
# Move the model to the GPU and switch it to evaluation mode before sampling
# (assumes GSPVD follows the usual torch.nn.Module API, where both calls return
# the module — TODO confirm).
model = model.cuda().eval()

In [None]:
from my_schedulers.ddpm_scheduler import DDPMSparseScheduler
from my_schedulers.ddim_scheduler import DDIMSparseScheduler
from utils.helper_functions import process_ckpt

def get_sched(steps):
    """Build the sampling scheduler for a run with ``steps`` steps.

    Reads the notebook globals ``scheduler``, ``hparams`` and, for non-DDIM
    runs only, ``distilled``.

    Args:
        steps: Number of sampling steps passed to the scheduler.

    Returns:
        A ``DDIMSparseScheduler`` when ``scheduler == 'ddim'``, otherwise a
        ``DDPMSparseScheduler``. ``init_steps`` is ``hparams['n_steps']`` for
        DDIM and for distilled DDPM runs, and ``steps`` for plain DDPM runs.
    """
    use_ddim = scheduler == 'ddim'
    # `distilled` is only looked up when DDIM is not selected, so a DDIM run
    # works even before that global has been defined.
    use_training_steps = use_ddim or distilled
    sched_cls = DDIMSparseScheduler if use_ddim else DDPMSparseScheduler

    return sched_cls(
        beta_min=hparams['beta_min'],
        beta_max=hparams['beta_max'],
        steps=steps,
        init_steps=hparams['n_steps'] if use_training_steps else steps,
        mode=hparams['mode'],
    )

def get_ckpt(steps, conditiona, step_size=1):
    """Load and post-process the checkpoint for the current run configuration.

    Reads the notebook globals ``distilled``, ``scheduler`` and ``categories``.

    Args:
        steps: Step count; only used to pick the file for distilled runs.
        conditiona: Unused (the name is spelled this way in the signature;
            callers pass it positionally).
        step_size: Unused.

    Returns:
        Whatever ``process_ckpt`` returns for the loaded checkpoint; the caller
        passes it to ``model.load_state_dict``.

    Note:
        ``weights_only=False`` lets ``torch.load`` unpickle arbitrary objects,
        so only load checkpoint files from a trusted source.
    """
    if distilled:
        # Distilled models have one checkpoint per step count.
        ckpt_path = f'../checkpoints/distillation/GSPVD/{"-".join(categories)}/new/{steps}-steps.ckpt'
    else:
        # Non-distilled runs always use a 1000-step checkpoint; only the
        # directory depends on the scheduler.
        ckpt_path = (
            f'../checkpoints/distillation/GSPVD/{"-".join(categories)}/1000-steps.ckpt'
            if scheduler == 'ddim'
            else f'../checkpoints/ShapeNet/GSPVD/{"-".join(categories)}/{scheduler}/1000-steps.ckpt'
        )

    return process_ckpt(torch.load(ckpt_path, weights_only=False))
    

In [None]:
import time

# Point clouds to visualize: the ground truth first, then one generated cloud
# per entry of `steps_to_run`.
# NOTE(review): this name shadows IPython's built-in `display` helper for the
# rest of the session; later cells read it under this name, so it is kept.
display = [real_pc]

# True selects the distilled per-step checkpoints; False selects the base
# 1000-step checkpoint. Read by get_sched and get_ckpt.
distilled = False

# Starting-noise policy:
#   * a distilled run always starts from newly created noise;
#   * any other run reuses the noise left in the kernel by an earlier execution
#     of this cell, so the outputs can be compared on the same noise.
# On a fresh kernel there is nothing to reuse. The original code assigned
# `start_noise` only when `distilled` was True, so running with
# distilled=False after a restart raised NameError at the `is None` check
# below. Initialise it here whenever it does not exist yet.
if distilled or "start_noise" not in globals():
    start_noise = None

for steps in steps_to_run:
    # Load the weights for this step count and build the matching scheduler.
    ckpt = get_ckpt(steps, conditional)
    model.load_state_dict(ckpt)
    sched = get_sched(steps)

    # Created once, then shared by every later step count in this loop.
    if start_noise is None:
        start_noise = sched.create_noise((1, 2048, 3), device='cuda')

    start_time = time.time()
    gen_pc = sched.sample(model, 1, 2048, 3, reference=ref_img, starting_noise=start_noise, guidance_scale=1)
    # print(f"Generated point cloud in {time.time() - start_time:.2f} seconds for {steps} steps")

    # Keep the first (only) cloud of the batch as a NumPy array on the CPU.
    display.append(gen_pc[0].cpu().detach().numpy())

In [None]:
from utils.visualization import display_pointclouds_grid
# Render every cloud collected in `display` (ground truth first, then the
# generated samples).
# NOTE(review): grid_dims=(3, 4) presumably means a 3x4 layout, i.e. room for
# 12 clouds — confirm if `steps_to_run` is extended.
display_pointclouds_grid(display, offset=10, point_size=0.3, grid_dims=(3, 4))

In [None]:
from utils.helper_functions import send_to_local
# Label for the exported result. `steps_to_run` is a list, so its list repr
# (e.g. "[1000]") ends up inside the label.
run_label = f"distillation-{steps_to_run}-2" if distilled else f"ddim-{steps_to_run}-2"

# Export the collected clouds together with the reference identifiers. The
# trailing semicolon keeps the cell from echoing the return value, as before.
send_to_local(display, ref['filename'], ref['selected-view'], run_label, categories[0]);