In [None]:
%cd ../src
%load_ext autoreload
%autoreload 2

In [None]:
from models.ddpm_unet_cattn import SPVUnet
import torch
import lightning as L
from models.g_spvd import GSPVD

In [15]:
## Hyperparameters
steps_to_run = [1000, 500, 250, 125, 63, 32]
on_all = True
scheduler = 'ddpm'
distilled = True
step_size = 1

categories = ['airplane']

In [16]:
from torch.utils.data import DataLoader
from dataloaders.shapenet.shapenet_loader import ShapeNet

path = "../data/ShapeNet"

test_dataset = ShapeNet(path, "test", 2048, categories, load_renders=True, total=800 if on_all else 5)
test_loader = DataLoader(test_dataset, batch_size=32, num_workers=4)

Loading (test) renders for airplane (02691156):   0%|          | 0/808 [00:00<?, ?it/s]

In [17]:
from utils.hyperparams import load_hyperparams

hparams_path = f'../checkpoints/distillation/GSPVD/{"-".join(categories)}/hparams.yaml'

hparams = load_hyperparams(hparams_path)

In [18]:
model_args = {
    'voxel_size' : hparams['voxel_size'],
    'nfs' : hparams['nfs'], 
    'attn_chans' : hparams['attn_chans'], 
    'attn_start' : hparams['attn_start'], 
    'cross_attn_chans' : hparams['cross_attn_chans'], 
    'cross_attn_start' : hparams['cross_attn_start'], 
    'cross_attn_cond_dim' : hparams['cross_attn_cond_dim'],
}

model = SPVUnet(**model_args)
model = GSPVD(model=model)

In [19]:
model = model.cuda().eval()

In [20]:
from my_schedulers.ddpm_scheduler import DDPMSparseScheduler
from my_schedulers.ddim_scheduler import DDIMSparseScheduler
from utils.helper_functions import process_ckpt

def get_sched(steps, step_size=1):
    if scheduler == 'ddim' and distilled:
        sched = DDIMSparseScheduler(
            beta_min=hparams['beta_min'], 
            beta_max=hparams['beta_max'], 
            steps=steps, 
            init_steps=hparams['n_steps'],
            mode=hparams['mode'],
        )
    elif scheduler == 'ddim':
        sched = DDIMSparseScheduler(
            beta_min=hparams['beta_min'], 
            beta_max=hparams['beta_max'], 
            steps=steps, 
            init_steps=hparams['n_steps'],
            mode=hparams['mode'],
            step_size=step_size,
        )
    elif distilled:
        sched = DDPMSparseScheduler(
            beta_min=hparams['beta_min'], 
            beta_max=hparams['beta_max'], 
            steps=steps, 
            init_steps=hparams['n_steps'],
            mode=hparams['mode'],
        )
    else:
        sched = DDPMSparseScheduler(
            beta_min=hparams['beta_min'], 
            beta_max=hparams['beta_max'], 
            steps=steps, 
            init_steps=steps,
            mode=hparams['mode'],
        )
    return sched

def get_ckpt(steps, conditional, step_size=1):
    if distilled:
        ckpt_path = f'../checkpoints/distillation/GSPVD/{"-".join(categories)}/{"cond" if conditional else "uncond"}/{steps}-steps.ckpt'
    elif scheduler == 'ddim' and step_size > 1:
        ckpt_path = f'../checkpoints/distillation/GSPVD/{"-".join(categories)}/1000-steps.ckpt'
    else:
        ckpt_path = f'../checkpoints/ShapeNet/GSPVD/{"-".join(categories)}/{scheduler}/{steps}-steps.ckpt'

    ckpt = torch.load(ckpt_path, weights_only=False)
    ckpt = process_ckpt(ckpt)
    return ckpt

In [22]:
from tqdm.auto import tqdm
from metrics.evaluation_metrics import cham3D, EMD
from utils.helper_functions import normalize_to_unit_sphere

def run_test(steps, step_size, scheduler):
    if distilled:
        ckpt = get_ckpt(steps, conditional=True)
    else:
        ckpt = get_ckpt(steps, conditional=True, step_size=step_size)

    model.load_state_dict(ckpt, strict=False)
    model.eval()

    sched = get_sched(steps, step_size=step_size)

    cd_mean = 0
    emd_mean = 0
    i = 0
    
    for datapoint in tqdm(test_loader):
        i += 1
        ref_pc = datapoint['pc'].cuda()
        features = datapoint['render-features'].cuda()

        B, N, C = ref_pc.shape
        gen_pc = sched.sample(model, B, N, reference=features, guidance_scale=1)
        
        ref_pc = normalize_to_unit_sphere(ref_pc)
        gen_pc = normalize_to_unit_sphere(gen_pc)

        d1, d2 = cham3D(ref_pc, gen_pc)
        cd = d1.mean(dim=1) + d2.mean(dim=1)
        emd = EMD(ref_pc, gen_pc, transpose=False)
        
        cd_mean += cd.mean()
        emd_mean += emd.mean()
    
    cd_mean /= i
    emd_mean /= i
    
    steps = len(sched.t_steps)
    
    print(f"Steps: {steps}, CD: {cd_mean.item()}, EMD: {emd_mean.item()}")
    return cd_mean, emd_mean, steps
    
means = [run_test(steps, step_size, scheduler) for steps in steps_to_run]
    

  0%|          | 0/26 [00:00<?, ?it/s]

Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Steps: 1000, CD: 0.008592253550887108, EMD: 0.05579596012830734


  0%|          | 0/26 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Steps: 500, CD: 0.009172680787742138, EMD: 0.057976674288511276


  0%|          | 0/26 [00:00<?, ?it/s]

Sampling:   0%|          | 0/250 [00:00<?, ?it/s]

Sampling:   0%|          | 0/250 [00:00<?, ?it/s]

Sampling:   0%|          | 0/250 [00:00<?, ?it/s]

Sampling:   0%|          | 0/250 [00:00<?, ?it/s]

Sampling:   0%|          | 0/250 [00:00<?, ?it/s]

Sampling:   0%|          | 0/250 [00:00<?, ?it/s]

Sampling:   0%|          | 0/250 [00:00<?, ?it/s]

Sampling:   0%|          | 0/250 [00:00<?, ?it/s]

Sampling:   0%|          | 0/250 [00:00<?, ?it/s]

Sampling:   0%|          | 0/250 [00:00<?, ?it/s]

Sampling:   0%|          | 0/250 [00:00<?, ?it/s]

Sampling:   0%|          | 0/250 [00:00<?, ?it/s]

Sampling:   0%|          | 0/250 [00:00<?, ?it/s]

Sampling:   0%|          | 0/250 [00:00<?, ?it/s]

Sampling:   0%|          | 0/250 [00:00<?, ?it/s]

Sampling:   0%|          | 0/250 [00:00<?, ?it/s]

Sampling:   0%|          | 0/250 [00:00<?, ?it/s]

Sampling:   0%|          | 0/250 [00:00<?, ?it/s]

Sampling:   0%|          | 0/250 [00:00<?, ?it/s]

Sampling:   0%|          | 0/250 [00:00<?, ?it/s]

Sampling:   0%|          | 0/250 [00:00<?, ?it/s]

Sampling:   0%|          | 0/250 [00:00<?, ?it/s]

Sampling:   0%|          | 0/250 [00:00<?, ?it/s]

Sampling:   0%|          | 0/250 [00:00<?, ?it/s]

Sampling:   0%|          | 0/250 [00:00<?, ?it/s]

Sampling:   0%|          | 0/250 [00:00<?, ?it/s]

Steps: 250, CD: 0.009313728660345078, EMD: 0.059037432074546814


  0%|          | 0/26 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Steps: 125, CD: 0.00936738308519125, EMD: 0.06267478317022324


  0%|          | 0/26 [00:00<?, ?it/s]

Sampling:   0%|          | 0/63 [00:00<?, ?it/s]

Sampling:   0%|          | 0/63 [00:00<?, ?it/s]

Sampling:   0%|          | 0/63 [00:00<?, ?it/s]

Sampling:   0%|          | 0/63 [00:00<?, ?it/s]

Sampling:   0%|          | 0/63 [00:00<?, ?it/s]

Sampling:   0%|          | 0/63 [00:00<?, ?it/s]

Sampling:   0%|          | 0/63 [00:00<?, ?it/s]

Sampling:   0%|          | 0/63 [00:00<?, ?it/s]

Sampling:   0%|          | 0/63 [00:00<?, ?it/s]

Sampling:   0%|          | 0/63 [00:00<?, ?it/s]

Sampling:   0%|          | 0/63 [00:00<?, ?it/s]

Sampling:   0%|          | 0/63 [00:00<?, ?it/s]

Sampling:   0%|          | 0/63 [00:00<?, ?it/s]

Sampling:   0%|          | 0/63 [00:00<?, ?it/s]

Sampling:   0%|          | 0/63 [00:00<?, ?it/s]

Sampling:   0%|          | 0/63 [00:00<?, ?it/s]

Sampling:   0%|          | 0/63 [00:00<?, ?it/s]

Sampling:   0%|          | 0/63 [00:00<?, ?it/s]

Sampling:   0%|          | 0/63 [00:00<?, ?it/s]

Sampling:   0%|          | 0/63 [00:00<?, ?it/s]

Sampling:   0%|          | 0/63 [00:00<?, ?it/s]

Sampling:   0%|          | 0/63 [00:00<?, ?it/s]

Sampling:   0%|          | 0/63 [00:00<?, ?it/s]

Sampling:   0%|          | 0/63 [00:00<?, ?it/s]

Sampling:   0%|          | 0/63 [00:00<?, ?it/s]

Sampling:   0%|          | 0/63 [00:00<?, ?it/s]

Steps: 63, CD: 0.00897583831101656, EMD: 0.06076054275035858


  0%|          | 0/26 [00:00<?, ?it/s]

Sampling:   0%|          | 0/32 [00:00<?, ?it/s]

Sampling:   0%|          | 0/32 [00:00<?, ?it/s]

Sampling:   0%|          | 0/32 [00:00<?, ?it/s]

Sampling:   0%|          | 0/32 [00:00<?, ?it/s]

Sampling:   0%|          | 0/32 [00:00<?, ?it/s]

Sampling:   0%|          | 0/32 [00:00<?, ?it/s]

Sampling:   0%|          | 0/32 [00:00<?, ?it/s]

Sampling:   0%|          | 0/32 [00:00<?, ?it/s]

Sampling:   0%|          | 0/32 [00:00<?, ?it/s]

Sampling:   0%|          | 0/32 [00:00<?, ?it/s]

Sampling:   0%|          | 0/32 [00:00<?, ?it/s]

Sampling:   0%|          | 0/32 [00:00<?, ?it/s]

Sampling:   0%|          | 0/32 [00:00<?, ?it/s]

Sampling:   0%|          | 0/32 [00:00<?, ?it/s]

Sampling:   0%|          | 0/32 [00:00<?, ?it/s]

Sampling:   0%|          | 0/32 [00:00<?, ?it/s]

Sampling:   0%|          | 0/32 [00:00<?, ?it/s]

Sampling:   0%|          | 0/32 [00:00<?, ?it/s]

Sampling:   0%|          | 0/32 [00:00<?, ?it/s]

Sampling:   0%|          | 0/32 [00:00<?, ?it/s]

Sampling:   0%|          | 0/32 [00:00<?, ?it/s]

Sampling:   0%|          | 0/32 [00:00<?, ?it/s]

Sampling:   0%|          | 0/32 [00:00<?, ?it/s]

Sampling:   0%|          | 0/32 [00:00<?, ?it/s]

Sampling:   0%|          | 0/32 [00:00<?, ?it/s]

Sampling:   0%|          | 0/32 [00:00<?, ?it/s]

Steps: 32, CD: 0.009257766418159008, EMD: 0.06218022480607033


In [39]:
import os

def save_means(means, steps_to_run, scheduler, idx):
    path = f'../metrics/{"-".join(categories)}/{scheduler}/{"distilled" if distilled else "skip"}/cond/means/'
    os.makedirs(os.path.dirname(path), exist_ok=True)

    filename = f"{path}/means_{idx}.res"
    string = ""
    for cd, emd, steps in means:
        string += f"Steps: {steps:4d}, CD: {cd.item():.6f}, EMD: {emd.item():.6f}\n"
        
    with open(filename, "w") as f:
        f.write(string)
        
    print(f"Saved means to {filename}")
    
save_means(means, steps_to_run, scheduler, 0)

Saved means to ../metrics/airplane/ddpm/distilled/cond/means//means_0.res
