In [18]:
# Work from the project's src/ directory so the local packages imported below
# (models, dataloaders, utils, my_schedulers, metrics, ...) resolve.
%cd ../src
# Automatically re-import edited project modules before each cell runs.
%load_ext autoreload
%autoreload 2

/home/ubuntu/SPVD_Lightning/src
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [19]:
from models.ddpm_unet_cattn import SPVUnet
import torch
import lightning as L
from models.g_spvd import GSPVD

In [29]:
## Hyperparameters
# Sampling step counts to evaluate. Sweeps used previously:
#   [32, 16, 8, 4, 2]
#   [500, 250, 125, 63, 32, 16, 8, 4, 2]
steps_to_run = [4]

scheduler = 'ddim'      # 'ddim' -> DDIM sampler; any other value -> DDPM sampler (see get_sched)
distilled = True        # select the distilled "combined" checkpoints (see get_ckpt)
on_all = True           # NOTE(review): not referenced by any visible cell

categories = ['chair']  # category names; also select the checkpoint / metrics folders

In [30]:
from torch.utils.data import DataLoader
from dataloaders.shapenet.shapenet_loader import ShapeNet

path = "../data/ShapeNet"

# Evaluation data: the "val" split of the selected categories.
# NOTE(review): 2048 is presumably the number of points per cloud, and
# load_renders=True presumably provides the 'render-features' that run_test
# uses for conditioning -- confirm against ShapeNet's signature.
test_dataset = ShapeNet(
    path,
    "val",
    2048,
    categories,
    load_renders=True,
)
test_loader = DataLoader(test_dataset, num_workers=0, batch_size=32)

Loading (val) renders for chair (03001627):   0%|          | 0/662 [00:00<?, ?it/s]

In [31]:
from utils.hyperparams import load_hyperparams

# Hyperparameters stored next to the distilled checkpoints for this category set.
# They configure both the network (model_args) and the samplers (get_sched).
hparams_path = f'../checkpoints/distillation/GSPVD/{"-".join(categories)}/hparams.yaml'
hparams = load_hyperparams(hparams_path)

In [32]:
# Build the network from the settings recorded in hparams.yaml (same keys, same order).
model_args = {
    key: hparams[key]
    for key in (
        'voxel_size',
        'nfs',
        'attn_chans',
        'attn_start',
        'cross_attn_chans',
        'cross_attn_start',
        'cross_attn_cond_dim',
    )
}

# SPVUnet backbone wrapped in GSPVD; the wrapper is what run_test loads weights into and samples with.
model = GSPVD(model=SPVUnet(**model_args))

In [33]:
# Move the model to the GPU and switch to inference mode (run_test calls eval() again after loading each checkpoint).
model = model.cuda().eval()

In [34]:
from my_schedulers.ddpm_scheduler import DDPMSparseScheduler
from my_schedulers.ddim_scheduler import DDIMSparseScheduler
from utils.helper_functions import process_ckpt

def get_sched(steps, dist, scheduler):
    """Create the sparse diffusion sampler used to generate point clouds.

    beta_min, beta_max and mode always come from ``hparams``. The sampler class
    and ``init_steps`` depend on the flags:

    * ``scheduler == 'ddim'``  -> DDIMSparseScheduler, init_steps = hparams['n_steps']
    * otherwise, ``dist``      -> DDPMSparseScheduler, init_steps = hparams['n_steps']
    * otherwise                -> DDPMSparseScheduler, init_steps = steps
      (presumably paired with a checkpoint trained for exactly ``steps`` steps; see get_ckpt)

    Args:
        steps: number of sampling steps.
        dist: whether distilled checkpoints are being evaluated.
        scheduler: sampler name; only ``'ddim'`` is special-cased.

    Returns:
        The constructed scheduler instance.
    """
    use_ddim = scheduler == 'ddim'
    sampler_cls = DDIMSparseScheduler if use_ddim else DDPMSparseScheduler
    # Only the plain (non-distilled) DDPM path ties the base schedule to `steps`.
    init_steps = hparams['n_steps'] if (use_ddim or dist) else steps

    return sampler_cls(
        beta_min=hparams['beta_min'],
        beta_max=hparams['beta_max'],
        steps=steps,
        init_steps=init_steps,
        mode=hparams['mode'],
    )

def get_ckpt(steps, dist, scheduler):
    """Load a GSPVD checkpoint from disk and post-process it with ``process_ckpt``.

    File selection under ``../checkpoints/distillation/GSPVD/<categories>/``:

    * ``dist``                          -> ``combined/<steps>-steps.ckpt``
    * not ``dist``, ``'ddim'`` sampler  -> ``1000-steps.ckpt`` (same file for every step count)
    * otherwise                         -> ``<steps>-steps.ckpt``

    Returns:
        Whatever ``process_ckpt`` produces; run_test feeds it to ``model.load_state_dict``.
    """
    if not dist and scheduler == 'ddim':
        ckpt_path = f'../checkpoints/distillation/GSPVD/{"-".join(categories)}/1000-steps.ckpt'
    elif not dist:
        ckpt_path = f'../checkpoints/distillation/GSPVD/{"-".join(categories)}/{steps}-steps.ckpt'
    else:
        ckpt_path = f'../checkpoints/distillation/GSPVD/{"-".join(categories)}/combined/{steps}-steps.ckpt'

    # weights_only=False unpickles arbitrary objects: only load trusted, locally produced checkpoints.
    raw_ckpt = torch.load(ckpt_path, weights_only=False)
    return process_ckpt(raw_ckpt)

In [35]:
from tqdm.auto import tqdm
from metrics.chamfer_dist import ChamferDistanceL2
from metrics.PyTorchEMD.emd import earth_mover_distance as EMD
from utils.helper_functions import normalize_to_unit_sphere, standardize, normalize_to_unit_cube
from schedulers.factory import create_sparse_scheduler

from metrics.rgb2point import chamfer_distance, EMDLoss

emd_loss = EMDLoss()  # only referenced by the commented-out EMD line in run_test


def _center_scale_subsample(pc, radius=0.64, n_points=1024):
    """Center each cloud, rescale it to a fixed max radius, and randomly keep ``n_points`` points.

    Args:
        pc: batched point clouds of shape (B, N, C), indexed as pc[b, point, coord].
        radius: max distance from the origin after rescaling (0.64 is this notebook's
            evaluation convention).
        n_points: number of points kept per cloud after shuffling.

    Returns:
        Tensor of shape (B, min(N, n_points), C).
    """
    pc = pc - pc.mean(dim=1, keepdim=True)
    # Per-cloud max distance from the origin, shape (B, 1, 1). The clamp avoids 0/0 -> NaN
    # for a degenerate cloud whose points all coincide.
    max_dist = torch.linalg.vector_norm(pc, dim=-1, keepdim=True).amax(dim=1, keepdim=True)
    pc = pc / max_dist.clamp_min(1e-12) * radius
    # One random permutation is shared by all clouds in the batch; it still yields
    # a uniform random subset of points for each cloud.
    perm = torch.randperm(pc.shape[1], device=pc.device)
    return pc[:, perm[:n_points]]


def run_test(steps):
    """Generate one cloud per validation sample and return the mean Chamfer distance.

    Uses the notebook-level ``distilled`` / ``scheduler`` settings to pick the
    sampler (get_sched) and checkpoint (get_ckpt) for ``steps`` sampling steps,
    loads the weights into ``model``, and samples conditioned on each batch's
    'render-features'. Reference and generated clouds are normalized identically
    (centered, max radius 0.64, 1024 random points). The per-pair bidirectional
    ``chamfer_distance`` (halved) is then averaged over all samples.

    Args:
        steps: number of sampling steps.

    Returns:
        Mean of the per-sample values returned by ``chamfer_distance``.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    sched = get_sched(steps, distilled, scheduler)
    model.load_state_dict(get_ckpt(steps, distilled, scheduler))
    model.eval()

    cd_sum = 0
    n = 0
    with torch.no_grad():
        for datapoint in tqdm(test_loader):
            ref_pc = datapoint['pc'].cuda()
            features = datapoint['render-features'].cuda()

            B, N, _ = ref_pc.shape
            gen_pc = sched.sample(model, B, N, reference=features)

            ref_pc = _center_scale_subsample(ref_pc)
            gen_pc = _center_scale_subsample(gen_pc)

            # Move each batch to the CPU once, then score every (reference, generated) pair.
            for ref, gen in zip(ref_pc.cpu(), gen_pc.cpu()):
                cd_sum += chamfer_distance(ref, gen, direction='bi') / 2
            # emd_mean += emd_loss(ref_pc, gen_pc)

            n += B

    if n == 0:
        raise ValueError("test_loader produced no samples; cannot compute a mean CD")
    cd_mean = cd_sum / n

    print(f"Steps: {steps}, CD: {cd_mean}")
    return cd_mean

In [36]:
import os
import numpy as np

def save_means(cds, out_dir=None, name=None):
    """Write per-step Chamfer-distance statistics to ``<out_dir>/<name>.res``.

    Each line is ``"<steps>, <mean>, <std>, <min>"``, with steps in descending
    order. The statistics cover every value recorded for that step count, and
    both flat lists and lists of lists are accepted. Step counts with no values
    are skipped. If nothing remains, no file is written, so an earlier results
    file is never clobbered with an empty one.

    Args:
        cds: mapping ``steps -> CD values``, as built by the evaluation loop.
        out_dir: output directory. Defaults to ``../metrics/<categories>``.
            Missing directories, including parents, are created.
        name: file stem. Defaults to ``'ddpm'`` when ``scheduler == 'ddpm'``,
            otherwise ``'combined'``.
    """
    if out_dir is None:
        out_dir = f'../metrics/{"-".join(categories)}'
    if name is None:
        name = 'ddpm' if scheduler == 'ddpm' else 'combined'
    filename = f"{out_dir}/{name}.res"

    lines = []
    for steps in sorted(cds.keys(), reverse=True):
        cd = np.asarray(cds[steps], dtype=float).ravel()
        if cd.size == 0:  # e.g. a sweep interrupted before its first run finished
            continue
        lines.append(f"{steps}, {cd.mean()}, {cd.std()}, {cd.min()}\n")

    if not lines:
        print(f"No results to save; {filename} left untouched")
        return

    # Create the output directory itself (not just its parent) so open() succeeds.
    os.makedirs(out_dir, exist_ok=True)
    with open(filename, "w") as f:
        f.write("".join(lines))

    print(f"Saved means to {filename}")

In [37]:
from collections import defaultdict

# Evaluate every configured step count several times (sampling is stochastic,
# so save_means reports mean / std / min across runs).
N_REPEATS = 3

cds = defaultdict(list)
interrupted = False
for steps in steps_to_run:
    try:
        for _ in range(N_REPEATS):
            # Store each run as soon as it finishes so a KeyboardInterrupt keeps completed runs.
            # Compute first, then append, so an interrupted run never leaves an empty entry.
            cd = run_test(steps)
            cds[steps].append(cd)
    except KeyboardInterrupt:
        interrupted = True
        break

if interrupted:
    # Partial sweeps are not written automatically; persist them explicitly if wanted.
    print("Interrupted: completed runs are kept in `cds`; call save_means(cds) to persist them.")
else:
    save_means(cds)
        

  0%|          | 0/21 [00:00<?, ?it/s]

Sampling:   0%|          | 0/4 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Sampling:   0%|          | 0/4 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Sampling:   0%|          | 0/4 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Sampling:   0%|          | 0/4 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Sampling:   0%|          | 0/4 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Sampling:   0%|          | 0/4 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Sampling:   0%|          | 0/4 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Sampling:   0%|          | 0/4 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Sampling:   0%|          | 0/4 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Sampling:   0%|          | 0/4 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Sampling:   0%|          | 0/4 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Sampling:   0%|          | 0/4 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Sampling:   0%|          | 0/4 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Sampling:   0%|          | 0/4 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Sampling:   0%|          | 0/4 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Sampling:   0%|          | 0/4 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Sampling:   0%|          | 0/4 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Sampling:   0%|          | 0/4 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Sampling:   0%|          | 0/4 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Sampling:   0%|          | 0/4 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Sampling:   0%|          | 0/4 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Steps: 4, CD: 0.053195813594169666


  0%|          | 0/21 [00:00<?, ?it/s]

Sampling:   0%|          | 0/4 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Sampling:   0%|          | 0/4 [00:00<?, ?it/s]

In [None]:
# Persist whatever is currently in `cds` (the sweep loop above saves automatically only when it is not interrupted).
save_means(cds)

Saved means to ../metrics/chair/combined.res
