In [None]:
# Work from src/ (presumably so the project packages such as `models`, `utils`, `dataloaders`
# import directly); every relative path used later ("../data", "../checkpoints") resolves from here.
%cd ../src
# Automatically reload edited project modules before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
from models.ddpm_unet_cattn import SPVUnet
import torch
import lightning as L
from models.g_spvd import GSPVD

In [None]:
## Hyperparameters
# ShapeNet categories under evaluation; also selects the checkpoint / hparams directory.
categories = ['car']

# Sampling-step counts to evaluate (run_test is called once per entry).
# Full sweep: [1000, 500, 250, 125, 63, 32, 16, 8, 4, 2]
steps_to_run = [1000]

# Sampler: 'ddim' selects DDIMSparseScheduler, anything else a DDPMSparseScheduler.
scheduler = 'ddim'
# True -> load the distilled `new/{steps}-steps.ckpt`; False -> the 1000-step checkpoint.
distilled = False
# NOTE(review): not referenced by any visible cell.
on_all = True

In [None]:
from utils.hyperparams import load_hyperparams

# Hyperparameters stored in the same directory as the GSPVD checkpoints for `categories`.
# Later cells read these keys from it: voxel_size, nfs, attn_chans, attn_start, cross_attn_chans,
# cross_attn_start, cross_attn_cond_dim, beta_min, beta_max, n_steps, mode.
hparams_path = f'../checkpoints/distillation/GSPVD/{"-".join(categories)}/hparams.yaml'

hparams = load_hyperparams(hparams_path)

In [None]:
from dataloaders.giannis_shapenet import ShapeNet15kPointCloudsViTEmbs, ShapeNet15kPointClouds
from torchsparse.utils.collate import sparse_collate_fn
from torch.utils.data import DataLoader

dataset_path = "../data/ShapeNet/pointclouds"
emb_path = "../data/ShapeNet/embed_renders"

# Test split of the selected categories, paired with precomputed ViT render embeddings.
# 2048 points per shape; random_subsample=False presumably fixes which points are taken -- confirm.
dataset = ShapeNet15kPointCloudsViTEmbs(
    dataset_path,
    emb_path,
    split='test',
    categories=categories,
    tr_sample_size=2048,
    random_subsample=False,
)

# Fixed order and the last partial batch is kept, so one pass visits every test shape once.
dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=False,
    drop_last=False,
    collate_fn=sparse_collate_fn,
)

In [None]:
# Backbone constructor arguments, copied verbatim from the saved hparams (same keys, same order).
model_args = {
    key: hparams[key]
    for key in (
        'voxel_size',
        'nfs',
        'attn_chans',
        'attn_start',
        'cross_attn_chans',
        'cross_attn_start',
        'cross_attn_cond_dim',
    )
}

# Wrap the SPVUnet backbone in the GSPVD module; weights are loaded later by run_test.
model = GSPVD(model=SPVUnet(**model_args))

In [None]:
# Move the model to the GPU and switch to eval mode; run_test later loads weights into this same instance.
model = model.cuda().eval()

In [None]:
from my_schedulers.ddpm_scheduler import DDPMSparseScheduler
from my_schedulers.ddim_scheduler import DDIMSparseScheduler
from utils.helper_functions import process_ckpt
from schedulers.factory import create_sparse_scheduler


def get_sched(steps, dist, scheduler):
    """Build the sampling scheduler for an evaluation run.

    Dispatch, checked in this order:
      * ``scheduler == 'ddim'`` -> DDIMSparseScheduler with ``init_steps=hparams['n_steps']``.
      * ``dist`` is truthy      -> DDPMSparseScheduler with ``init_steps=hparams['n_steps']``.
      * otherwise               -> DDPMSparseScheduler with ``init_steps=steps``.

    Every variant takes ``beta_min``, ``beta_max`` and ``mode`` from ``hparams`` and
    ``steps`` from the argument.

    Args:
        steps: number of sampling steps the scheduler will run.
        dist: whether the distilled model is being evaluated.
        scheduler: sampler name; only ``'ddim'`` is special-cased.

    Returns:
        A DDIMSparseScheduler or DDPMSparseScheduler instance.
    """
    common = dict(
        beta_min=hparams['beta_min'],
        beta_max=hparams['beta_max'],
        steps=steps,
        mode=hparams['mode'],
    )

    if scheduler == 'ddim':
        return DDIMSparseScheduler(init_steps=hparams['n_steps'], **common)

    # Only the DDPM base-schedule length differs between distilled and non-distilled runs.
    base_steps = hparams['n_steps'] if dist else steps
    return DDPMSparseScheduler(init_steps=base_steps, **common)

def get_ckpt(steps, dist, scheduler):
    """Load and post-process the GSPVD checkpoint for an evaluation run.

    Args:
        steps: sampling-step count; selects ``new/{steps}-steps.ckpt`` when ``dist`` is true.
        dist: if truthy, load the distilled checkpoint for ``steps``; otherwise load
            ``1000-steps.ckpt``.
        scheduler: kept for signature parity with ``get_sched``. Both the DDIM and the
            non-DDIM path load the same non-distilled file, so it does not affect the result.

    Returns:
        The checkpoint after ``process_ckpt``, ready for ``model.load_state_dict``.
    """
    if not dist:
        # DDIM and DDPM evaluation both use the same 1000-step weights.
        path = f'../checkpoints/distillation/GSPVD/{"-".join(categories)}/1000-steps.ckpt'
    else:
        path = f'../checkpoints/distillation/GSPVD/{"-".join(categories)}/new/{steps}-steps.ckpt'

    # weights_only=False lets torch.load unpickle arbitrary objects -- only use trusted checkpoint files.
    return process_ckpt(torch.load(path, weights_only=False))

In [None]:
from tqdm.auto import tqdm
from metrics.chamfer_dist import ChamferDistanceL2
from metrics.PyTorchEMD.emd import earth_mover_distance as EMD
from utils.helper_functions import normalize_to_unit_sphere, standardize, normalize_to_unit_cube
from schedulers.factory import create_sparse_scheduler

from metrics.rgb2point import chamfer_distance, EMDLoss

# Single EMD criterion instance shared by every run_test call.
emd_loss = EMDLoss()

def _normalize_and_subsample(pc, radius=0.64, n_points=1024):
    """Center each cloud, scale its farthest point to ``radius``, then keep ``n_points`` random points.

    ``pc`` is indexed as (batch, points, coords): the mean is taken over dim 1 and the
    norm over the last dim. One permutation is drawn and applied to every cloud in the
    batch. Subsampling is random, but the clouds are not shuffled independently.
    """
    pc = pc - pc.mean(dim=1, keepdim=True)
    # Per-cloud max distance from the origin, shape (B, 1) -> (B, 1, 1) for broadcasting.
    max_dist = (pc * pc).sum(dim=-1).sqrt().max(dim=1, keepdim=True)[0]
    pc = pc / max_dist.unsqueeze(-1) * radius
    perm = torch.randperm(pc.shape[1])
    return pc[:, perm][:, :n_points]


def run_test(steps, seed=None, radius=0.64, n_eval_points=1024):
    """Evaluate the loaded GSPVD model on the test dataloader for one sampling-step count.

    Loads the scheduler and checkpoint selected by the notebook-level ``distilled`` and
    ``scheduler`` settings. For each batch it generates one cloud per reference and
    normalizes both sides with ``_normalize_and_subsample``. It then accumulates the
    per-shape Chamfer distance (bi-directional / 2) and the batch EMD.

    Args:
        steps: number of sampling steps (also selects the distilled checkpoint).
        seed: if not None, ``torch.manual_seed(seed)`` is called first so the point
            subsampling, and presumably the sampler noise, is reproducible.
        radius: max distance from the origin after normalization.
        n_eval_points: points kept per cloud before computing metrics.

    Returns:
        ``(cd_mean, emd_mean)``: the Chamfer sum and the EMD sum, each divided by the number of shapes.

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    if seed is not None:
        torch.manual_seed(seed)

    sched = get_sched(steps, distilled, scheduler)
    model.load_state_dict(get_ckpt(steps, distilled, scheduler))
    model.eval()

    cd_sum = 0
    emd_sum = 0
    n = 0

    with torch.no_grad():
        progress = tqdm(dataloader, desc=f"{steps} steps")
        for datapoint in progress:
            ref_pc = datapoint['train_points'].cuda()
            features = datapoint['vit_emb'].cuda()

            B, N, _ = ref_pc.shape
            gen_pc = sched.sample(model, B, N, reference=features)

            # Same call order as before: reference first, then generated (keeps the RNG stream identical).
            ref_pc = _normalize_and_subsample(ref_pc, radius, n_eval_points)
            gen_pc = _normalize_and_subsample(gen_pc, radius, n_eval_points)

            for ref_i, gen_i in zip(ref_pc, gen_pc):
                cd_sum += chamfer_distance(ref_i.detach().cpu(), gen_i.detach().cpu(), direction='bi') / 2

            # NOTE(review): dividing by the shape count below assumes emd_loss returns a per-batch SUM.
            # If EMDLoss returns a batch mean, emd_mean is under-reported by roughly the batch size -- confirm.
            emd_sum += emd_loss(ref_pc, gen_pc)

            n += B
            progress.set_postfix_str(f"CD: {cd_sum / n}")

    if n == 0:
        raise ValueError("dataloader yielded no samples; cannot average metrics")

    cd_mean = cd_sum / n
    emd_mean = emd_sum / n
    print(f"Steps: {steps}, CD: {cd_mean}, EMD: {emd_mean}")

    return (cd_mean, emd_mean)

In [None]:
# Evaluate every configured step count. Results are keyed by step count so a
# multi-step sweep keeps every result instead of only the last one.
n_repeats = 1  # independent evaluation passes per step count
results = {}
for steps in steps_to_run:
    means = [run_test(steps) for _ in range(n_repeats)]
    results[steps] = means
    # TODO: persist `means` -- a `save_means(means, steps)` helper was referenced but is not defined here.
results

In [None]:
# Display the evaluated categories as a final check of the run configuration.
categories