In [20]:
# Work from src/ (see the printed cwd below): the project packages imported in the
# next cell (models, dataloaders, ...) and the relative ../data, ../checkpoints and
# ../metrics paths used later all assume this working directory.
# autoreload picks up edits to those packages without restarting the kernel.
%cd ../src
%load_ext autoreload
%autoreload 2

/home/ubuntu/SPVD_Lightning/src
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [21]:
from models.ddpm_unet_cattn import SPVUnet
import torch
import lightning as L
from models.g_spvd import GSPVD

In [22]:
## Hyperparameters
# Sampling-step counts to evaluate; each one selects its own schedule/checkpoint
# (get_sched / get_ckpt) and gets its own results file (save_means).
steps_to_run = [500, 250, 125, 63, 32, 16, 8, 4, 2]

# Controls the `total` passed to the test ShapeNet dataset: 800 when True, 5 when
# False (quick smoke test).
on_all = True

# True: load per-step checkpoints from checkpoints/distillation/.../new/ (see get_ckpt).
distilled = True

# 'ddim' -> DDIMSparseScheduler; anything else -> DDPMSparseScheduler (see get_sched).
scheduler = 'ddim'

categories = ['chair']

# Fixed seed so the stochastic sampling -- and therefore the saved metrics -- can be
# reproduced with Restart & Run All. Repeated runs in the evaluation loop still differ
# from each other because the RNG state keeps advancing.
# NOTE(review): assumes the samplers draw their noise from torch's RNG -- confirm.
seed = 0
torch.manual_seed(seed)  # seeds the RNG on all devices, CUDA included

In [23]:
from torch.utils.data import DataLoader
from dataloaders.shapenet.shapenet_loader import ShapeNet

path = "../data/ShapeNet"

# Size passed as ShapeNet(total=...): full evaluation vs. a 5-shape smoke test.
n_test_shapes = 800 if on_all else 5

# NOTE(review): 2048 is presumably the number of points per cloud -- confirm in ShapeNet.
test_dataset = ShapeNet(
    path,
    "test",
    2048,
    categories,
    load_renders=True,
    total=n_test_shapes,
)
test_loader = DataLoader(test_dataset, batch_size=32, num_workers=0)

Loading (test) renders for chair (03001627):   0%|          | 0/1317 [00:00<?, ?it/s]

In [24]:
from utils.hyperparams import load_hyperparams

# Model / diffusion hyperparameters stored next to the distillation checkpoints.
# NOTE(review): these are loaded from the distillation directory even when get_ckpt
# loads a checkpoints/ShapeNet/<scheduler>/ checkpoint (distilled=False and
# scheduler != 'ddim') -- confirm both were trained with the same hparams.
category_tag = "-".join(categories)
hparams_path = f'../checkpoints/distillation/GSPVD/{category_tag}/hparams.yaml'
hparams = load_hyperparams(hparams_path)

In [25]:
# SPVUnet constructor arguments, copied one-to-one from the loaded hparams.
_unet_arg_names = (
    'voxel_size',
    'nfs',
    'attn_chans',
    'attn_start',
    'cross_attn_chans',
    'cross_attn_start',
    'cross_attn_cond_dim',
)
model_args = {name: hparams[name] for name in _unet_arg_names}

# Wrap the U-Net in the GSPVD module; weights are loaded later per step count (run_test).
model = GSPVD(model=SPVUnet(**model_args))

In [26]:
# Move the model to the GPU, then switch it to eval mode (affects layers such as
# Dropout / BatchNorm).
model = model.cuda()
model = model.eval()

In [27]:
from my_schedulers.ddpm_scheduler import DDPMSparseScheduler
from my_schedulers.ddim_scheduler import DDIMSparseScheduler
from utils.helper_functions import process_ckpt

def get_sched(steps, dist, scheduler):
    """Build the sparse sampler that generates point clouds in `steps` steps.

    Selection:
      * scheduler == 'ddim'          -> DDIMSparseScheduler, init_steps = hparams['n_steps']
      * otherwise, dist is truthy    -> DDPMSparseScheduler, init_steps = hparams['n_steps']
      * otherwise                    -> DDPMSparseScheduler, init_steps = steps

    beta_min, beta_max and mode always come from the global `hparams`.

    Args:
        steps: number of sampling steps.
        dist: whether a distilled checkpoint is being evaluated.
        scheduler: scheduler name; only 'ddim' is special-cased.

    Returns:
        The constructed scheduler instance.
    """
    use_ddim = scheduler == 'ddim'
    sched_cls = DDIMSparseScheduler if use_ddim else DDPMSparseScheduler

    # Only the non-distilled DDPM path sizes its base schedule to `steps`; the other
    # two keep hparams['n_steps'] (presumably the training schedule length).
    keep_base_schedule = use_ddim or dist

    return sched_cls(
        beta_min=hparams['beta_min'],
        beta_max=hparams['beta_max'],
        steps=steps,
        init_steps=hparams['n_steps'] if keep_base_schedule else steps,
        mode=hparams['mode'],
    )

def get_ckpt(steps, dist, scheduler):
    """Load the checkpoint to evaluate and run it through `process_ckpt`.

    Path selection (relative to src/):
      * dist is truthy        -> checkpoints/distillation/GSPVD/<cats>/new/<steps>-steps.ckpt
      * scheduler == 'ddim'   -> checkpoints/distillation/GSPVD/<cats>/1000-steps.ckpt
                                 (the same file for every step count)
      * otherwise             -> checkpoints/ShapeNet/GSPVD/<cats>/<scheduler>/<steps>-steps.ckpt

    Args:
        steps: number of sampling steps (part of the filename in two of the branches).
        dist: whether to load a distilled checkpoint.
        scheduler: scheduler name, used for the branch choice and the ShapeNet path.

    Returns:
        Whatever `process_ckpt` returns for the loaded checkpoint.
    """
    cats = "-".join(categories)

    if dist:
        ckpt_path = f'../checkpoints/distillation/GSPVD/{cats}/new/{steps}-steps.ckpt'
    elif scheduler == 'ddim':
        ckpt_path = f'../checkpoints/distillation/GSPVD/{cats}/1000-steps.ckpt'
    else:
        ckpt_path = f'../checkpoints/ShapeNet/GSPVD/{cats}/{scheduler}/{steps}-steps.ckpt'

    # weights_only=False unpickles arbitrary objects: only use it on trusted local files.
    return process_ckpt(torch.load(ckpt_path, weights_only=False))

In [28]:
from tqdm.auto import tqdm
from metrics.chamfer_dist import ChamferDistanceL2
from metrics.PyTorchEMD.emd import earth_mover_distance as EMD
from utils.helper_functions import normalize_to_unit_sphere, standardize, normalize_to_unit_cube

def run_test(steps):
    """Sample point clouds with a `steps`-step schedule and score them on the test set.

    Loads the checkpoint for `steps` (per the global `distilled` / `scheduler`
    settings) into the global `model`. It then generates one point cloud per test
    sample, conditioned on its render features. Chamfer (L2) and Earth Mover's
    distances are accumulated in two settings:

      * centered: each cloud is translated so its mean point is at the origin;
      * normalized: each cloud is passed through `normalize_to_unit_sphere`.

    Args:
        steps: number of sampling steps; also selects the checkpoint and schedule.

    Returns:
        ((cd, emd), (cd_norm, emd_norm)): per-sample averages over the test set.

    Raises:
        ValueError: if `test_loader` yields no samples.
    """
    CD = ChamferDistanceL2()

    sched = get_sched(steps, distilled, scheduler)

    ckpt = get_ckpt(steps, distilled, scheduler)
    model.load_state_dict(ckpt)
    model.eval()

    cd_sum = 0.0
    emd_sum = 0.0
    cd_sum_norm_sphere = 0.0
    emd_sum_norm_sphere = 0.0
    n = 0

    # Pure evaluation: no autograd graph is needed, and building one through every
    # sampling step would only waste GPU memory.
    with torch.no_grad():
        for datapoint in tqdm(test_loader):
            ref_pc = datapoint['pc'].cuda()
            features = datapoint['render-features'].cuda()

            B, N, C = ref_pc.shape
            gen_pc = sched.sample(model, B, N, reference=features, guidance_scale=1)

            # Remove each cloud's mean so the "centered" metrics ignore global translation.
            # (Previously these were computed but unused, and the metrics labelled
            # "centered" were taken on the raw clouds.)
            ref_pc_zero = ref_pc - ref_pc.mean(dim=1, keepdim=True)
            gen_pc_zero = gen_pc - gen_pc.mean(dim=1, keepdim=True)

            # The `* B` / `.sum()` accumulation assumes CD returns a batch mean and EMD
            # returns per-sample values -- NOTE(review): confirm against metrics/.
            cd_sum += CD(ref_pc_zero, gen_pc_zero).item() * B
            emd_sum += EMD(ref_pc_zero, gen_pc_zero, transpose=False).sum().item()

            ref_pc_norm = normalize_to_unit_sphere(ref_pc)
            gen_pc_norm = normalize_to_unit_sphere(gen_pc)

            cd_sum_norm_sphere += CD(ref_pc_norm, gen_pc_norm).item() * B
            emd_sum_norm_sphere += EMD(ref_pc_norm, gen_pc_norm, transpose=False).sum().item()

            n += B

    if n == 0:
        raise ValueError("run_test: test_loader yielded no samples; cannot average metrics")

    cd_mean = cd_sum / n
    emd_mean = emd_sum / n

    cd_mean_norm_sphere = cd_sum_norm_sphere / n
    emd_mean_norm_sphere = emd_sum_norm_sphere / n

    print(f"Steps: {steps}, CD: {cd_mean}, EMD: {emd_mean} (centered)")
    print(f"Steps: {steps}, CD: {cd_mean_norm_sphere}, EMD: {emd_mean_norm_sphere} (normalized to unit sphere)")

    return (cd_mean, emd_mean), (cd_mean_norm_sphere, emd_mean_norm_sphere)
    

In [None]:
import os

def save_means(means, steps):
    """Write the per-run metrics for one step count to a `.res` text file.

    The file is written to
    ../metrics/<categories>/<scheduler>/<distilled|skip>/means/means_<steps>.res
    It contains one section per run, followed by the best (minimum) value of each
    metric across the runs.

    Args:
        means: non-empty sequence of ((cd, emd), (cd_norm, emd_norm)) tuples, one per
            run, as returned by run_test.
        steps: number of sampling steps the runs used (header and filename).

    Raises:
        ValueError: if `means` is empty (there is no "best" value to report).
    """
    means = list(means)
    if not means:
        raise ValueError("save_means: need at least one run's metrics to save")

    out_dir = os.path.join(
        "../metrics",
        "-".join(categories),
        scheduler,
        "distilled" if distilled else "skip",
        "means",
    )
    os.makedirs(out_dir, exist_ok=True)
    filename = os.path.join(out_dir, f"means_{steps}.res")

    lines = []
    for (cd, emd), (cd_norm, emd_norm) in means:
        lines.append(f"Steps: {steps:4d}")
        lines.append(f"CD: {cd:.8f} | CD (norm): {cd_norm:.8f}")
        lines.append(f"EMD: {emd:.8f} | EMD (norm): {emd_norm:.8f}")
        lines.append("-" * 50)

    # Lower is better for every metric, so "best" is the minimum across runs.
    best_cd = min(run[0][0] for run in means)
    best_emd = min(run[0][1] for run in means)
    best_cd_norm = min(run[1][0] for run in means)
    best_emd_norm = min(run[1][1] for run in means)

    lines.append(f"Best CD: {best_cd:.8f} | Best CD (norm): {best_cd_norm:.8f}")
    lines.append(f"Best EMD: {best_emd:.8f} | Best EMD (norm): {best_emd_norm:.8f}")

    with open(filename, "w") as f:
        f.write("\n".join(lines) + "\n")

    print(f"Saved means to {filename}")

# Evaluate every step count several times (sampling is stochastic) and persist the
# per-run metrics. After the loop, `means` holds the runs for the last entry of
# steps_to_run.
n_runs = 5
for steps in steps_to_run:
    means = []
    for _run in range(n_runs):
        means.append(run_test(steps))
    save_means(means, steps)

  0%|          | 0/25 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

In [None]:
# ((CD, EMD), (CD_norm, EMD_norm)) tuples, one per run, for the LAST step count in
# steps_to_run only -- the loop above rebinds `means` on every iteration.
means