In [1]:
# Switch the working directory to the project's src/ folder so the project-local
# imports in the following cells (models, dataloaders, utils, ...) resolve.
# The path is relative to wherever the kernel was started.
%cd ../src
# Enable IPython's autoreload extension; mode 2 re-imports edited modules
# before each cell runs, so changes to project code are picked up without a restart.
%load_ext autoreload
%autoreload 2

/home/ubuntu/SPVD_Lightning/src


In [2]:
from models.ddpm_unet_cattn import SPVUnet
import torch
import lightning as L
from models.g_spvd import GSPVD

In [12]:
## Hyperparameters
# Sampling step counts to evaluate: the driver cell calls run_test once per entry
# (unless the DDIM skip-step sweep is active, see step_sizes).
steps_to_run = [125]
# True passes total=800 to the ShapeNet test set, False passes total=5 (quick check).
on_all = True
# Sampler family: 'ddim' selects DDIMSparseScheduler in get_sched, any other value
# falls through to DDPMSparseScheduler. Also used as a folder name in the metrics path.
scheduler = 'ddpm'
# True evaluates distilled checkpoints; False evaluates the non-distilled ("skip") setup.
distilled = True
# Step sizes swept by the driver cell only when scheduler == 'ddim' and not distilled.
step_sizes = [1, 2, 4, 8, 16, 32]

# ShapeNet categories; joined with '-' to build the hparams, checkpoint and metrics paths.
categories = ['airplane']

In [13]:
from torch.utils.data import DataLoader
from dataloaders.shapenet.shapenet_loader import ShapeNet

# Location of the ShapeNet data, relative to the src/ working directory.
path = "../data/ShapeNet"

# "test" split restricted to `categories`, with render data loaded as well.
# 2048 is presumably the number of points per cloud and `total` presumably caps
# the number of samples (800 for a full run, 5 for a quick check) - TODO confirm
# against the ShapeNet loader.
test_dataset = ShapeNet(path, "test", 2048, categories, load_renders=True, total=800 if on_all else 5)
# Batches of 32 with num_workers=0, i.e. data loading happens in the main process.
test_loader = DataLoader(test_dataset, batch_size=32, num_workers=0)

Loading (test) renders for airplane (02691156):   0%|          | 0/808 [00:00<?, ?it/s]

In [14]:
from utils.hyperparams import load_hyperparams

# hparams.yaml stored under the distillation checkpoint folder for the selected categories.
# NOTE(review): this path always points at the distillation folder, even when
# `distilled` is False - confirm the non-distilled runs share the same hyperparameters.
hparams_path = f'../checkpoints/distillation/GSPVD/{"-".join(categories)}/hparams.yaml'

# Indexed with [] by the model construction and scheduler setup in later cells.
hparams = load_hyperparams(hparams_path)

In [15]:
# Architecture settings copied one-to-one from the loaded hparams into the
# keyword arguments expected by SPVUnet.
model_args = {
    key: hparams[key]
    for key in (
        'voxel_size',
        'nfs',
        'attn_chans',
        'attn_start',
        'cross_attn_chans',
        'cross_attn_start',
        'cross_attn_cond_dim',
    )
}

# Build the SPVUnet backbone and wrap it in GSPVD; the wrapper is what the
# rest of the notebook refers to as `model`.
model = GSPVD(model=SPVUnet(**model_args))

In [16]:
# Move the wrapped model to the GPU and put it in evaluation mode for sampling.
model = model.cuda().eval()

In [17]:
from my_schedulers.ddpm_scheduler import DDPMSparseScheduler
from my_schedulers.ddim_scheduler import DDIMSparseScheduler
from utils.helper_functions import process_ckpt

def get_sched(steps, step_size=1):
    """Build the sampling scheduler selected by the notebook-level config.

    Reads the globals `scheduler`, `distilled` and `hparams`.

    Args:
        steps: number of sampling steps handed to the scheduler.
        step_size: only forwarded for the non-distilled DDIM case; ignored otherwise.

    Returns:
        A DDIMSparseScheduler when `scheduler == 'ddim'`, else a DDPMSparseScheduler.
    """
    use_ddim = scheduler == 'ddim'
    sched_cls = DDIMSparseScheduler if use_ddim else DDPMSparseScheduler

    # Only plain (non-distilled) DDPM starts from `steps` itself; every other
    # combination starts from the step count stored in the hparams.
    sched_kwargs = {
        'beta_min': hparams['beta_min'],
        'beta_max': hparams['beta_max'],
        'steps': steps,
        'init_steps': hparams['n_steps'] if (use_ddim or distilled) else steps,
        'mode': hparams['mode'],
    }

    # Skip-step DDIM is the only configuration that receives a step size.
    if use_ddim and not distilled:
        sched_kwargs['step_size'] = step_size

    return sched_cls(**sched_kwargs)

def get_ckpt(
    steps,
    conditional,
    step_size=1,
    override_path='/home/ubuntu/SPVD_Lightning/checkpoints/distillation/GSPVD/airplane/cond/125-steps/intemediate/125-steps-epoch=499.ckpt',
):
    """Load a checkpoint and return it after `process_ckpt`.

    The original code computed a checkpoint path from the configuration and then
    unconditionally replaced it with one pinned file, which made `steps` and
    `conditional` silently irrelevant. The pinned file is now an explicit
    keyword default, so existing calls load exactly the same checkpoint, while
    `override_path=None` restores the configuration-driven selection.

    Reads the globals `distilled`, `scheduler` and `categories`.

    Args:
        steps: step count used in the checkpoint file name (`{steps}-steps.ckpt`).
        conditional: picks the "cond" or "uncond" sub-folder for distilled checkpoints.
        step_size: accepted for call compatibility; not used.
        override_path: when not None, this file is loaded and the other
            arguments do not influence the path.

    Returns:
        Whatever `process_ckpt` returns for the loaded checkpoint; callers pass
        it to `model.load_state_dict`.
    """
    category_tag = "-".join(categories)

    if override_path is not None:
        ckpt_path = override_path
    elif distilled:
        ckpt_path = f'../checkpoints/distillation/GSPVD/{category_tag}/{"cond" if conditional else "uncond"}/{steps}-steps.ckpt'
    elif scheduler == 'ddim':
        ckpt_path = f'../checkpoints/distillation/GSPVD/{category_tag}/1000-steps.ckpt'
    else:
        ckpt_path = f'../checkpoints/ShapeNet/GSPVD/{category_tag}/{scheduler}/{steps}-steps.ckpt'

    # NOTE(review): weights_only=False unpickles arbitrary Python objects; only
    # load checkpoint files that come from a trusted source.
    ckpt = torch.load(ckpt_path, weights_only=False)
    return process_ckpt(ckpt)

In [18]:
from tqdm.auto import tqdm
from metrics.chamfer_dist import ChamferDistanceL2
from metrics.PyTorchEMD.emd import earth_mover_distance as EMD
from utils.helper_functions import normalize_to_unit_sphere, standardize, normalize_to_unit_cube

def run_test(steps, step_size, scheduler):
    """Sample a shape for every test item and report Chamfer / EMD against the reference.

    Loads the checkpoint into the global `model`, builds a scheduler via
    `get_sched`, generates one point cloud per test sample conditioned on its
    render features, and averages CD and EMD over the whole `test_loader` under
    four normalisations: centered, unit sphere, standardized and unit cube.

    Args:
        steps: sampling step count, forwarded to `get_ckpt` and `get_sched`.
        step_size: forwarded to `get_sched` (and to `get_ckpt` when not distilled).
        scheduler: not used inside this function; `get_sched` reads the
            notebook-level `scheduler` global instead.

    Returns:
        Tuple `(steps, cd_sphere, emd_sphere, cd_standard, emd_standard,
        cd_cube, emd_cube)` where `steps` is `len(sched.t_steps)`. The centered
        metrics are printed but not returned.

    Raises:
        ValueError: if `test_loader` yields no samples (previously a bare
            ZeroDivisionError).
    """
    CD = ChamferDistanceL2()
    # Distilled runs rely on get_ckpt's default step_size, exactly as before.
    ckpt = get_ckpt(steps, conditional=True, step_size=1 if distilled else step_size)

    # NOTE(review): strict=False ignores missing/unexpected keys, so a mismatched
    # checkpoint would not raise here.
    model.load_state_dict(ckpt, strict=False)
    model.eval()

    sched = get_sched(steps, step_size=step_size)

    # (key, label used in the printed report, transform applied to both clouds)
    variants = [
        ('centered', 'centered', lambda pc: pc - pc.mean(dim=1, keepdim=True)),
        ('sphere', 'normalized to unit sphere', normalize_to_unit_sphere),
        ('standard', 'standardized', standardize),
        ('cube', 'normalized to unit cube', normalize_to_unit_cube),
    ]
    cd_sum = {key: 0 for key, _, _ in variants}
    emd_sum = {key: 0 for key, _, _ in variants}
    n = 0

    for datapoint in tqdm(test_loader):
        ref_pc = datapoint['pc'].cuda()
        features = datapoint['render-features'].cuda()

        B, N, _ = ref_pc.shape
        gen_pc = sched.sample(model, B, N, reference=features, guidance_scale=1)

        for key, _, transform in variants:
            ref_t = transform(ref_pc)
            gen_t = transform(gen_pc)
            # CD is scaled by the batch size (presumably it returns a batch
            # average - confirm); EMD values are summed over the batch.
            cd_sum[key] += CD(ref_t, gen_t).item() * B
            emd_sum[key] += EMD(ref_t, gen_t, transpose=False).sum().item()

        n += B

    if n == 0:
        raise ValueError("test_loader yielded no samples; cannot average metrics")

    cd = {key: total / n for key, total in cd_sum.items()}
    emd = {key: total / n for key, total in emd_sum.items()}

    # Report the number of steps the scheduler actually uses.
    steps = len(sched.t_steps)

    for key, label, _ in variants:
        print(f"Steps: {steps}, CD: {cd[key]}, EMD: {emd[key]} ({label})")

    return (
        steps,
        cd['sphere'], emd['sphere'],
        cd['standard'], emd['standard'],
        cd['cube'], emd['cube'],
    )
    
# means = [run_test(steps, step_size, scheduler) for steps in steps_to_run]
    

In [19]:
import os

def save_means(means, steps_to_run, scheduler, idx):
    """Write the averaged metrics of one evaluation round to a text report.

    Reads the globals `categories` and `distilled` to build the output folder.

    Args:
        means: iterable of 7-tuples as returned by `run_test`.
        steps_to_run: not used; kept in the signature for existing callers.
        scheduler: scheduler name, used as a folder in the output path.
        idx: round index, used in the file name `means_{idx}.res`.
    """
    path = f'../metrics/{"-".join(categories)}/{scheduler}/{"distilled" if distilled else "skip"}/cond/means/'
    os.makedirs(os.path.dirname(path), exist_ok=True)
    filename = f"{path}means_{idx}.res"

    divider = "-" * 50 + "\n"
    report_parts = []
    for steps, cd, emd, cd_standard, emd_standard, cd_norm_cube, emd_norm_cube in means:
        report_parts.append(f"Steps: {steps:4d}\n")
        report_parts.append(
            f"CD: {cd:.4f} | CD (standard): {cd_standard:.4f} | CD (norm cube): {cd_norm_cube:.4f}\n"
        )
        report_parts.append(
            f"EMD: {emd:.4f} | EMD (standard): {emd_standard:.4f} | EMD (norm cube): {emd_norm_cube:.4f}\n"
        )
        report_parts.append(divider)

    with open(filename, "w") as out_file:
        out_file.write("".join(report_parts))

    print(f"Saved means to {filename}")

# One evaluation round per iteration; each round writes its own means_{i}.res file.
for i in range(1):
    # Non-distilled DDIM sweeps the step sizes at 1000 steps; every other
    # configuration sweeps the step counts with a step size of 1.
    means = (
        [run_test(1000, step_size, scheduler) for step_size in step_sizes]
        if scheduler == 'ddim' and not distilled
        else [run_test(steps, 1, scheduler) for steps in steps_to_run]
    )
    save_means(means, steps_to_run, scheduler, i)

  0%|          | 0/25 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Sampling:   0%|          | 0/125 [00:00<?, ?it/s]

Steps: 125, CD: 0.1284264701604843, EMD: 0.823938045501709 (centered)
Steps: 125, CD: 0.01114748727530241, EMD: 0.06499670252203942 (normalized to unit sphere)
Steps: 125, CD: 0.20822574734687804, EMD: 0.8258842158317566 (standardized)
Steps: 125, CD: 0.014033838286995887, EMD: 0.07392967700958251 (normalized to unit cube)
Saved means to ../metrics/airplane/ddpm/distilled/cond/means/means_0.res
