In [None]:
%cd ../src
%load_ext autoreload
%autoreload 2

In [None]:
import sys
import os

# NOTE: the first cell already ran `%cd ../src`, so at this point the working
# directory is the project's `src` folder, not the folder holding the notebook.
notebook_dir = os.getcwd()

# Parent of `src` (presumably the repository root) — make it importable too.
main_dir = os.path.abspath(os.path.join(notebook_dir, ".."))

# Prepend only once so that re-running this cell does not keep growing sys.path.
if main_dir not in sys.path:
    sys.path.insert(0, main_dir)

In [None]:
import torch
from torch.utils.data import DataLoader
from torchsparse.utils.collate import sparse_collate_fn
from tqdm import tqdm

from dataloaders.giannis_shapenet import ShapeNet15kPointClouds, ShapeNet15kPointCloudsViTEmbs
from dataloaders.transforms import DDPMNoisify, TorchSparseVoxelize
from metrics.chamfer_dist import ChamferDistanceL2
from metrics.PyTorchEMD.emd import earth_mover_distance as EMD
from metrics.rgb2point import EMDLoss, chamfer_distance
from models.ddpm_unet_cattn import SPVUnet
from models.g_spvd import GSPVD
from schedulers.factory import create_sparse_scheduler
from utils.helper_functions import process_ckpt

# ShapeNet point clouds and their precomputed render embeddings, both relative
# to the working directory (`src/`, see the first cell).
dataset_path, embedding_path = (
    "../data/ShapeNet/pointclouds",
    "../data/ShapeNet/embed_renders",
)

In [None]:
# ShapeNet categories to evaluate. Other values used with this notebook:
# ['airplane'] and ['chair']. The checkpoint (and noise scheduler) chosen in
# later cells are category-specific, so keep them in sync with this list.
categories = list(["car"])

In [None]:
# 1. Load the validation split and iterate over it in a fixed order,
#    keeping the last (possibly smaller) batch.
#    tr_sample_size=2048 is presumably the number of points per cloud — confirm.
dataset = ShapeNet15kPointCloudsViTEmbs(
    dataset_path,
    embedding_path,
    split='val',
    categories=categories,
    tr_sample_size=2048,
    random_subsample=False,
)
dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=False,
    drop_last=False,
    collate_fn=sparse_collate_fn,
)

In [None]:
# 2. Build the conditional sparse U-Net, wrap it in GSPVD and load trained weights.
spv_unet = SPVUnet(
    voxel_size=0.1,
    nfs=[32, 64, 128, 256],
    attn_chans=8,
    attn_start=3,
    cross_attn_chans=8,
    cross_attn_start=2,
    cross_attn_cond_dim=768,  # presumably the width of batch['vit_emb'] — TODO confirm
)
model = GSPVD(model=spv_unet)

# Checkpoint to evaluate; it must belong to the category selected in `categories`.
# Set the GSPVD_CKPT_PATH environment variable to evaluate another checkpoint
# without editing this cell. Other checkpoints evaluated with this notebook:
#   distilled, 1000 steps: ../checkpoints/distillation/GSPVD/{"-".join(categories)}/1000-steps.ckpt
#   airplane (ddpm):       /home/ubuntu/SPVD_Lightning/checkpoints/ShapeNet/GSPVD/airplane/ddpm/1000-steps.ckpt
#   car:                   /home/ubuntu/SPVD_Lightning/checkpoints/ImageGuidedSPVD/version_3/checkpoints/epoch=4999-step=385000.ckpt
#   chair:                 /home/ubuntu/SPVD_Lightning/checkpoints/ImageGuidedSPVD/version_0/checkpoints/epoch=2499-step=362500.ckpt
ckpt_path = os.environ.get(
    'GSPVD_CKPT_PATH',
    '/home/ubuntu/SPVD_Lightning/checkpoints/distillation/GSPVD/car/intemediate/500-steps/500-steps-epoch=249.ckpt',  # Car
)

# SECURITY: weights_only=False unpickles arbitrary objects — only load trusted checkpoints.
# map_location='cpu' avoids restoring a second copy of the weights onto the GPU
# they were saved from; the model itself is moved to CUDA in a later cell.
ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
ckpt = process_ckpt(ckpt)
model.load_state_dict(ckpt)

In [None]:
# Move the model to the current CUDA device and switch it to evaluation mode
# (`nn.Module.eval`). Both calls return the module itself, so rebinding keeps
# `model` pointing at the same object.
model = model.cuda()
model = model.eval()

In [None]:
# Noise-schedule settings per category key ("-".join(categories)), so the
# scheduler follows `categories` instead of being toggled by hand.
# Keys not listed here use the factory defaults (used for the chair and car models).
SCHEDULER_KWARGS = {
    'airplane': dict(beta_min=1e-5, beta_max=0.008, scheduling_method='warmup'),
}
scheduler = create_sparse_scheduler(**SCHEDULER_KWARGS.get('-'.join(categories), {}))

In [None]:

# 3. Evaluate on the validation split.
# For every batch, draw one conditional sample per item (conditioned on its ViT
# embedding), normalise generated and ground-truth clouds identically, and
# accumulate two independent implementations of Chamfer distance and EMD.
SEED = 0              # sampling and point subsampling are random; fix the RNG for reproducible numbers
NORM_RADIUS = 0.64    # distance of the farthest point from the centroid after rescaling
N_EVAL_POINTS = 1024  # points per cloud used for the metrics


def normalize_and_subsample(pc, n_points=N_EVAL_POINTS, radius=NORM_RADIUS):
    """Centre, rescale and randomly subsample a batch of point clouds.

    Each cloud is shifted to zero mean, scaled so that its farthest point lies at
    `radius` from the origin, then cut down to `n_points` points chosen by one
    random permutation shared by the whole batch.

    Args:
        pc: batch of point clouds indexed as (batch, points, coords) —
            assumed (B, N, 3); TODO confirm for the scheduler output.
        n_points: number of points kept per cloud.
        radius: target maximum distance from the centroid.

    Returns:
        Tensor of shape (batch, min(n_points, points), coords).
    """
    pc = pc - pc.mean(dim=1, keepdim=True)
    r = (pc * pc).sum(dim=-1).sqrt().max(dim=1, keepdim=True)[0]  # (batch, 1)
    pc = pc / r.unsqueeze(-1) * radius
    # Points are unordered, so the first n_points of a random permutation form a
    # random subset of each cloud.
    perm = torch.randperm(pc.shape[1])
    return pc[:, perm[:n_points]]


torch.manual_seed(SEED)

emd_loss_fn = EMDLoss()
CD = ChamferDistanceL2()

# Running sums over all validation samples; divided by n_samples at the end.
cd_loss = 0   # rgb2point chamfer_distance(gen, gt, metric="l2", direction="x_to_y") — one-directional, confirm
emd_loss = 0  # rgb2point EMDLoss
cd_1 = 0      # ChamferDistanceL2, assumed to return the batch mean (hence * B)
emd_1 = 0     # PyTorchEMD earth_mover_distance, assumed to return one value per sample
n_samples = 0

with torch.no_grad():
    progress = tqdm(dataloader)
    for batch in progress:
        gt_pc = batch['train_points'].cuda()
        cond_emb = batch['vit_emb'].cuda()
        B = gt_pc.shape[0]
        gen_pc = scheduler.sample(model, B, cond_emb=cond_emb, mode='conditional').cuda()

        gt_pc = normalize_and_subsample(gt_pc)
        gen_pc = normalize_and_subsample(gen_pc)

        cd_1 += CD(gt_pc, gen_pc).item() * B
        emd_1 += EMD(gt_pc, gen_pc, transpose=False).sum().item()

        for gpc, gtpc in zip(gen_pc, gt_pc):
            cd_loss += chamfer_distance(gpc.cpu().numpy(), gtpc.cpu().numpy(), metric="l2", direction="x_to_y")
            emd_loss += emd_loss_fn(gpc.unsqueeze(0), gtpc.unsqueeze(0)).item()
            n_samples += 1

        # Show running averages in the progress bar instead of printing every batch.
        progress.set_postfix(
            cd=cd_loss / n_samples,
            emd=emd_loss / n_samples,
            cd_1=cd_1 / n_samples,
            emd_1=emd_1 / n_samples,
        )

if n_samples == 0:
    raise RuntimeError('The validation dataloader yielded no samples; check dataset_path and categories.')

cd_loss = cd_loss / n_samples
emd_loss = emd_loss / n_samples

print(f'Chamfer Distance: {cd_loss} | EMD: {emd_loss}')
print(f'CD: {cd_1 / n_samples}, EMD: {emd_1 / n_samples}')
