In [2]:
%cd ../src

/home/ubuntu/SPVD_Lightning/src


In [3]:
from models.ddpm_unet_cattn import SPVUnet
import torch
import lightning as L
from models.g_spvd import GSPVD

In [None]:
## Model Selection
steps = 500
on_all = True
scheduler = 'ddpm'
distilled = False
step_size = 1
conditional = True

categories = ['chair']

In [5]:
from torch.utils.data import DataLoader
from dataloaders.shapenet.shapenet_loader import ShapeNet

path = "../data/ShapeNet"

test_dataset = ShapeNet(path, "test", 2048, categories, load_renders=True)
test_loader = DataLoader(test_dataset, batch_size=32, num_workers=4)

Loading (test) renders for chair (03001627):   0%|          | 0/1317 [00:00<?, ?it/s]

In [6]:
from utils.hyperparams import load_hyperparams

hparams_path = f'../checkpoints/distillation/GSPVD/{"-".join(categories)}/hparams.yaml'

hparams = load_hyperparams(hparams_path)

In [7]:
model_args = {
    'voxel_size' : hparams['voxel_size'],
    'nfs' : hparams['nfs'], 
    'attn_chans' : hparams['attn_chans'], 
    'attn_start' : hparams['attn_start'], 
    'cross_attn_chans' : hparams['cross_attn_chans'], 
    'cross_attn_start' : hparams['cross_attn_start'], 
    'cross_attn_cond_dim' : hparams['cross_attn_cond_dim'],
}

model = SPVUnet(**model_args)
model = GSPVD(model=model)

In [8]:
model = model.cuda().eval()

In [9]:
from my_schedulers.ddpm_scheduler import DDPMSparseScheduler
from my_schedulers.ddim_scheduler import DDIMSparseScheduler
from utils.visualization import display_pointclouds_grid

if scheduler == 'ddim' and distilled:
    sched = DDIMSparseScheduler(
        beta_min=hparams['beta_min'], 
        beta_max=hparams['beta_max'], 
        steps=steps, 
        init_steps=hparams['n_steps'],
        mode=hparams['mode'],
        step_size=step_size,
    )
elif distilled:
    sched = DDPMSparseScheduler(
        beta_min=hparams['beta_min'], 
        beta_max=hparams['beta_max'], 
        steps=steps, 
        init_steps=hparams['n_steps'],
        mode=hparams['mode'],
    )
else:
    sched = DDPMSparseScheduler(
        beta_min=hparams['beta_min'], 
        beta_max=hparams['beta_max'], 
        steps=steps, 
        init_steps=steps,
        mode=hparams['mode'],
    )

In [None]:
from utils.helper_functions import process_ckpt

if distilled:
    ckpt_path = f'../checkpoints/distillation/GSPVD/{"-".join(categories)}/{scheduler}/{steps}-steps.ckpt'
elif scheduler == 'ddim' and step_size > 1:
    ckpt_path = f'../checkpoints/distillation/GSPVD/{"-".join(categories)}/1000-steps.ckpt'
else:
    ckpt_path = f'../checkpoints/ShapeNet/GSPVD/{"-".join(categories)}/{scheduler}/{steps}-steps.ckpt'

ckpt = torch.load(ckpt_path, weights_only=False)
ckpt = process_ckpt(ckpt)
model.load_state_dict(ckpt)

<All keys matched successfully>

In [11]:
from tqdm.auto import tqdm
from metrics.evaluation_metrics import cham3D

all_ref_pc = []
all_gen_pc = []

mean = torch.tensor(test_loader.dataset.mean).cuda()
std = torch.tensor(test_loader.dataset.std).cuda()

i = 0
for datapoint in tqdm(test_loader):
    i += 1
    if i > 5 and not on_all:
        continue

    ref_pc = datapoint['pc'].cuda()
    features = datapoint['render-features'].cuda() if conditional else None

    B, N, C = ref_pc.shape
    gen_pc = sched.sample(model, B, N, reference=features)

    all_ref_pc.append(ref_pc)
    all_gen_pc.append(gen_pc)

all_ref_pc = torch.cat(all_ref_pc).cuda()
all_gen_pc = torch.cat(all_gen_pc).cuda()

  0%|          | 0/42 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

In [12]:
from metrics.evaluation_metrics import compute_all_metrics, jsd_between_point_cloud_sets as JSD
from pprint import pprint

In [None]:
results = compute_all_metrics(all_ref_pc, all_gen_pc, batch_size=32)
results = {k: (v.cpu().detach().item()
               if not isinstance(v, float) else v) for k, v in results.items()}

pprint(results)

In [12]:
jsd = JSD(all_gen_pc.cpu().numpy(), all_ref_pc.cpu().numpy())
results['JSD'] = jsd
pprint('JSD: {}'.format(jsd))

'JSD: 0.020793357553577252'


In [None]:
import json
import os
from math import ceil

if on_all:
    if distilled:
        folder = f'../metrics/{"-".join(categories)}/{scheduler}/distilled/'
    elif step_size > 1 and scheduler == 'ddim':
        folder = f'../metrics/{"-".join(categories)}/{scheduler}/skip-{step_size}/'
        step_size = ceil(steps / step_size)
    else:
        folder = f'../metrics/{"-".join(categories)}/{scheduler}/retrained/'

    folder += 'uncond/no-norm' if not conditional else 'cond/no-norm'
    
    os.makedirs(folder, exist_ok=True)

    file = os.path.join(folder, f'{steps}-steps.json')
    with open(file, 'w') as f:
        json.dump(results, f, indent=4)

In [13]:
def normalize_to_unit_sphere(batched_points: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """
    Normalize a batched tensor of 3D points to the unit sphere.
    
    Args:
        batched_points: (B, N, 3) tensor, where B = batch size, N = num points.
        eps: Small value to avoid division by zero.
    
    Returns:
        (B, N, 3) tensor, where all points lie within or on the unit sphere.
    """
    # Center points by subtracting their mean (centroid)
    centroid = torch.mean(batched_points, dim=1, keepdim=True)  # (B, 1, 3)
    centered = batched_points - centroid  # (B, N, 3)

    # Find the maximum distance from the origin for each batch
    max_dist = torch.max(
        torch.sqrt(torch.sum(centered ** 2, dim=-1, keepdim=True)),  # (B, N, 1)
        dim=1, keepdim=True
    ).values  # (B, 1, 1)

    # Normalize by dividing by the maximum distance (+ eps for stability)
    normalized = centered / (max_dist + eps)  # (B, N, 3)

    return normalized

In [14]:
all_gen_pc_norm = normalize_to_unit_sphere(all_gen_pc)
all_ref_pc_norm = normalize_to_unit_sphere(all_ref_pc)

In [15]:
results = compute_all_metrics(all_ref_pc_norm, all_gen_pc_norm, batch_size=32)
results = {k: (v.cpu().detach().item()
               if not isinstance(v, float) else v) for k, v in results.items()}

pprint(results)

  0%|          | 0/1317 [00:00<?, ?it/s]

  0%|          | 0/1317 [00:00<?, ?it/s]

  0%|          | 0/1317 [00:00<?, ?it/s]

{'1-NN-CD-acc': 0.5820045471191406,
 '1-NN-CD-acc_f': 0.45406225323677063,
 '1-NN-CD-acc_t': 0.709946870803833,
 '1-NN-EMD-acc': 0.5364464521408081,
 '1-NN-EMD-acc_f': 0.4836750328540802,
 '1-NN-EMD-acc_t': 0.5892179012298584,
 'CD-mean': 0.015033562667667866,
 'EMD-mean': 0.06812983006238937,
 'lgan_cov-CD': 0.5550493597984314,
 'lgan_cov-EMD': 0.553530752658844,
 'lgan_mmd-CD': 0.005658701993525028,
 'lgan_mmd-EMD': 0.0320662260055542,
 'lgan_mmd_smp-CD': 0.006564137991517782,
 'lgan_mmd_smp-EMD': 0.03450792655348778}


In [16]:
jsd = JSD(all_gen_pc.cpu().numpy(), all_ref_pc.cpu().numpy())
results['JSD'] = jsd
pprint('JSD: {}'.format(jsd))

'JSD: 0.006962009979952555'


In [17]:
import json
import os

if on_all:
    if distilled:
        folder = f'../metrics/{"-".join(categories)}/{scheduler}/distilled/'
    elif step_size > 1 and scheduler == 'ddim':
        folder = f'../metrics/{"-".join(categories)}/{scheduler}/skip-{step_size}/'
    else:
        folder = f'../metrics/{"-".join(categories)}/{scheduler}/retrained/'

    folder += 'uncond/norm' if not conditional else 'cond/norm'
    
    os.makedirs(folder, exist_ok=True)

    file = os.path.join(folder, f'{steps}-steps.json')
    with open(file, 'w') as f:
        json.dump(results, f, indent=4)