In [1]:
%cd ../src
%load_ext autoreload
%autoreload 2

/home/ubuntu/SPVD_Lightning/src


In [2]:
import os
import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils import data
from tqdm.auto import tqdm
import random

synsetid_to_cate = {
    '02691156': 'airplane', '02958343': 'car', '03001627': 'chair',
   
}
cate_to_synsetid = {v: k for k, v in synsetid_to_cate.items()}


class Uniform15KPC(Dataset):
    def __init__(self, root_dir, subdirs, tr_sample_size=10000,
                 te_sample_size=10000, split='train', scale=1.,
                 normalize_per_shape=False, random_subsample=False,
                 normalize_std_per_axis=False,
                 all_points_mean=None, all_points_std=None,
                 input_dim=3, load_renders=True):
        self.root_dir = root_dir
        self.split = split
        self.in_tr_sample_size = tr_sample_size
        self.in_te_sample_size = te_sample_size
        self.subdirs = subdirs
        self.scale = scale
        self.random_subsample = random_subsample
        self.input_dim = input_dim

        self.all_cate_mids = []
        self.cate_idx_lst = []
        self.all_points = []
        self.renders = []
        for cate_idx, subd in enumerate(self.subdirs):
            # NOTE: [subd] here is synset id
            sub_path = os.path.join(root_dir, "pointclouds", subd, self.split)
            if not os.path.isdir(sub_path):
                print("Directory missing : %s" % sub_path)
                continue

            all_mids = []
            for x in os.listdir(sub_path):
                if not x.endswith('.npy'):
                    continue
                all_mids.append(os.path.join(self.split, x[:-len('.npy')]))

            # NOTE: [mid] contains the split: i.e. "train/<mid>" or "val/<mid>" or "test/<mid>"
            for mid in tqdm(all_mids):
                # obj_fname = os.path.join(sub_path, x)
                obj_fname = os.path.join(root_dir, "pointclouds", subd, mid + ".npy")
                try:
                    point_cloud = np.load(obj_fname)  # (15k, 3)
                except:
                    continue
                
                if load_renders:
                    render_features = []
                    for view in range(8):
                        render_file = os.path.join(root_dir, 'embed_renders', subd, self.split, obj_fname.split('/')[-1].split('.')[0], f"00{view}_patch_embs.pt")
                        if os.path.exists(render_file):
                            render_features.append(torch.load(render_file, weights_only=True))
                    render_features = torch.stack(render_features, dim=0)
                    self.renders.append(render_features)

                assert point_cloud.shape[0] == 15000
                self.all_points.append(point_cloud[np.newaxis, ...])
                self.cate_idx_lst.append(cate_idx)
                self.all_cate_mids.append((subd, mid))

        # Shuffle the index deterministically (based on the number of examples)
        self.shuffle_idx = list(range(len(self.all_points)))
        random.Random(38383).shuffle(self.shuffle_idx)
        self.cate_idx_lst = [self.cate_idx_lst[i] for i in self.shuffle_idx]
        self.all_points = [self.all_points[i] for i in self.shuffle_idx]
        self.all_cate_mids = [self.all_cate_mids[i] for i in self.shuffle_idx]

        # Normalization
        self.all_points = np.concatenate(self.all_points)  # (N, 15000, 3)
        self.normalize_per_shape = normalize_per_shape
        self.normalize_std_per_axis = normalize_std_per_axis
        if all_points_mean is not None and all_points_std is not None:  # using loaded dataset stats
            self.all_points_mean = all_points_mean
            self.all_points_std = all_points_std
        elif self.normalize_per_shape:  # per shape normalization
            B, N = self.all_points.shape[:2]
            self.all_points_mean = self.all_points.mean(axis=1).reshape(B, 1, input_dim)
            if normalize_std_per_axis:
                self.all_points_std = self.all_points.reshape(B, N, -1).std(axis=1).reshape(B, 1, input_dim)
            else:
                self.all_points_std = self.all_points.reshape(B, -1).std(axis=1).reshape(B, 1, 1)
        else:  # normalize across the dataset
            self.all_points_mean = self.all_points.reshape(-1, input_dim).mean(axis=0).reshape(1, 1, input_dim)
            if normalize_std_per_axis:
                self.all_points_std = self.all_points.reshape(-1, input_dim).std(axis=0).reshape(1, 1, input_dim)
            else:
                self.all_points_std = self.all_points.reshape(-1).std(axis=0).reshape(1, 1, 1)

        self.all_points = (self.all_points - self.all_points_mean) / self.all_points_std
        self.train_points = self.all_points[:, :10000]
        self.test_points = self.all_points[:, 10000:]

        self.tr_sample_size = min(10000, tr_sample_size)
        self.te_sample_size = min(5000, te_sample_size)
        print("Total number of data:%d" % len(self.train_points))
        print("Min number of points: (train)%d (test)%d"
              % (self.tr_sample_size, self.te_sample_size))
        assert self.scale == 1, "Scale (!= 1) is deprecated"

    def get_pc_stats(self, idx):
        if self.normalize_per_shape:
            m = self.all_points_mean[idx].reshape(1, self.input_dim)
            s = self.all_points_std[idx].reshape(1, -1)
            return m, s

        return self.all_points_mean.reshape(1, -1), self.all_points_std.reshape(1, -1)

    def renormalize(self, mean, std):
        self.all_points = self.all_points * self.all_points_std + self.all_points_mean
        self.all_points_mean = mean
        self.all_points_std = std
        self.all_points = (self.all_points - self.all_points_mean) / self.all_points_std
        self.train_points = self.all_points[:, :10000]
        self.test_points = self.all_points[:, 10000:]

    def __len__(self):
        return len(self.train_points)

    def __getitem__(self, idx):
        tr_out = self.train_points[idx]
        if self.random_subsample:
            tr_idxs = np.random.choice(tr_out.shape[0], self.tr_sample_size)
        else:
            tr_idxs = np.arange(self.tr_sample_size)
        tr_out = torch.from_numpy(tr_out[tr_idxs, :]).float()

        te_out = self.test_points[idx]
        if self.random_subsample:
            te_idxs = np.random.choice(te_out.shape[0], self.te_sample_size)
        else:
            te_idxs = np.arange(self.te_sample_size)
        te_out = torch.from_numpy(te_out[te_idxs, :]).float()

        renders = self.renders[idx]
        selected_view = np.random.randint(0, renders.shape[0])
        render_features = renders[selected_view]

        m, s = self.get_pc_stats(idx)
        cate_idx = self.cate_idx_lst[idx]
        sid, mid = self.all_cate_mids[idx]

        return {
            'idx': idx,
            'train_points': tr_out,
            'pc': te_out,
            "render-features": render_features,
            "selected-view": selected_view,
            'mean': m, 'std': s, 'cate_idx': cate_idx,
            'sid': sid, 'mid': mid
        }

class ShapeNet15kPointClouds(Uniform15KPC):
    def __init__(self, root_dir="data/ShapeNetCore.v2.PC15k",
                 categories=['airplane'], tr_sample_size=10000, te_sample_size=2048,
                 split='train', scale=1., normalize_per_shape=False,
                 normalize_std_per_axis=False,
                 random_subsample=False,
                 all_points_mean=None, all_points_std=None, load_renders=True):
        self.root_dir = root_dir
        self.split = split
        assert self.split in ['train', 'test', 'val']
        self.tr_sample_size = tr_sample_size
        self.te_sample_size = te_sample_size
        self.cates = categories
        if 'all' in categories:
            self.synset_ids = list(cate_to_synsetid.values())
        else:
            self.synset_ids = [cate_to_synsetid[c] for c in self.cates]

        # assert 'v2' in root_dir, "Only supporting v2 right now."
        self.gravity_axis = 1
        self.display_axis_order = [0, 2, 1]

        super(ShapeNet15kPointClouds, self).__init__(
            root_dir, self.synset_ids,
            tr_sample_size=tr_sample_size,
            te_sample_size=te_sample_size,
            split=split, scale=scale,
            normalize_per_shape=normalize_per_shape,
            normalize_std_per_axis=normalize_std_per_axis,
            random_subsample=random_subsample,
            all_points_mean=all_points_mean, all_points_std=all_points_std,
            input_dim=3, load_renders=load_renders)


def init_np_seed(worker_id):
    seed = torch.initial_seed()
    np.random.seed(seed % 4294967296)
    

def get_datasets(args):
    if args.dataset_type == 'shapenet15k':
        tr_dataset = ShapeNet15kPointClouds(
            categories=args.cates, split='train',
            tr_sample_size=args.tr_max_sample_points,
            te_sample_size=args.te_max_sample_points,
            scale=args.dataset_scale, root_dir=args.data_dir,
            normalize_per_shape=args.normalize_per_shape,
            normalize_std_per_axis=args.normalize_std_per_axis,
            random_subsample=True, load_renders=False)
        te_dataset = ShapeNet15kPointClouds(
            categories=args.cates, split='val',
            tr_sample_size=args.tr_max_sample_points,
            te_sample_size=args.te_max_sample_points,
            scale=args.dataset_scale, root_dir=args.data_dir,
            normalize_per_shape=args.normalize_per_shape,
            normalize_std_per_axis=args.normalize_std_per_axis,
            all_points_mean=tr_dataset.all_points_mean,
            all_points_std=tr_dataset.all_points_std,
        )
    else:
        raise Exception("Invalid dataset type:%s" % args.dataset_type)

    return te_dataset

In [3]:
def get_test_dataset(path, cates = ['chair']):
    # using the same parameters as point flow
    class Args: pass
    args = Args()
    args.data_dir = path
    args.dataset_type = 'shapenet15k'
    args.tr_max_sample_points = 2048
    args.te_max_sample_points = 2048
    args.dataset_scale = 1.
    args.normalize_per_shape = False
    args.normalize_std_per_axis = False
    args.cates = cates

    test_dataset = get_datasets(args)

    return test_dataset

categories = ['airplane']

s = get_test_dataset('../data/ShapeNet', categories)

  0%|          | 0/2832 [00:00<?, ?it/s]

Total number of data:2832
Min number of points: (train)2048 (test)2048


  0%|          | 0/405 [00:00<?, ?it/s]

Total number of data:405
Min number of points: (train)2048 (test)2048


In [4]:
from models.ddpm_unet_cattn import SPVUnet
import torch
import lightning as L
from models.g_spvd import GSPVD

In [5]:
## Hyperparameters
# steps_to_run = [1000, 500, 250, 125, 63, 32, 16, 8, 4, 2]
steps_to_run = [1000]
on_all = True
distilled = False
scheduler = 'ddpm'

# categories = ['car']

In [6]:
from torch.utils.data import DataLoader
from dataloaders.shapenet.shapenet_loader import ShapeNet

path = "../data/ShapeNet"

# test_dataset = ShapeNet(path, "test", 2048, categories, load_renders=True, total=800 if on_all else 5)
test_loader = DataLoader(s, batch_size=64, num_workers=0)

In [7]:
from utils.hyperparams import load_hyperparams

hparams_path = f'../checkpoints/distillation/GSPVD/{"-".join(categories)}/hparams.yaml'

hparams = load_hyperparams(hparams_path)

In [8]:
model_args = {
    'voxel_size' : hparams['voxel_size'],
    'nfs' : hparams['nfs'], 
    'attn_chans' : hparams['attn_chans'], 
    'attn_start' : hparams['attn_start'], 
    'cross_attn_chans' : hparams['cross_attn_chans'], 
    'cross_attn_start' : hparams['cross_attn_start'], 
    'cross_attn_cond_dim' : hparams['cross_attn_cond_dim'],
}

model = SPVUnet(**model_args)
model = GSPVD(model=model)

In [9]:
model = model.cuda().eval()

In [10]:
from my_schedulers.ddpm_scheduler import DDPMSparseScheduler
from my_schedulers.ddim_scheduler import DDIMSparseScheduler
from utils.helper_functions import process_ckpt
from schedulers.factory import create_sparse_scheduler


def get_sched(steps, dist, scheduler):
    
    return create_sparse_scheduler(
        scheduler_strategy=scheduler.upper(),
        beta_min=hparams['beta_min'], 
        beta_max=hparams['beta_max'], 
        n_steps=steps, 
        scheduling_method='linear' if hparams['mode'] == 'linear' else 'warmup',
    )

def get_ckpt(steps, dist, scheduler):
    if dist:
        ckpt_path = f'../checkpoints/distillation/GSPVD/{"-".join(categories)}/new/{steps}-steps.ckpt'
    elif scheduler == 'ddim':
        ckpt_path = f'../checkpoints/distillation/GSPVD/{"-".join(categories)}/1000-steps.ckpt'
    else:
        ckpt_path = f'../checkpoints/distillation/GSPVD/{"-".join(categories)}/1000-steps.ckpt'

    ckpt = torch.load(ckpt_path, weights_only=False)
    ckpt = process_ckpt(ckpt)
    return ckpt

In [11]:
from tqdm.auto import tqdm
from metrics.chamfer_dist import ChamferDistanceL2
from metrics.PyTorchEMD.emd import earth_mover_distance as EMD
from utils.helper_functions import normalize_to_unit_sphere, standardize, normalize_to_unit_cube

def run_test(steps):
    CD = ChamferDistanceL2()
    
    sched = get_sched(steps, distilled, scheduler)

    ckpt = get_ckpt(steps, distilled, scheduler)
    model.load_state_dict(ckpt)
    model.eval()

    cd_mean = 0
    emd_mean = 0
    cd_mean_norm_sphere = 0
    emd_mean_norm_sphere = 0
    n = 0
    
    # mean = torch.tensor(test_dataset.mean).cuda()
    # std = torch.tensor(test_dataset.std).cuda()
    
    with torch.no_grad():
        for datapoint in tqdm(test_loader):
            # print(list(datapoint.keys()))
            ref_pc = datapoint['pc'].cuda()
            features = datapoint['render-features'].cuda()

            B, N, C = ref_pc.shape
            gen_pc = sched.sample(model=model, bs=B, n_points=N, nf=C, cond_emb=features, mode='conditional').cuda()
            
            # print(ref_pc.device, gen_pc.device)
            # ref_pc = ref_pc * std + mean
            # gen_pc = gen_pc * std + mean
            
            # ref_pc: (B, N, F)
            ref_pc_zero = ref_pc - ref_pc.mean(dim=1, keepdim=True)
            gen_pc_zero = gen_pc - gen_pc.mean(dim=1, keepdim=True)
            
            # print(ref_pc_zero.max(), ref_pc_zero.min())

            # ref_pc_zero.mean: (B, 1, F)
            # ref_pc_zero.std: (B, 1, 1) 
            # ref_pc_zero = ref_pc_zero / ref_pc_zero.std(dim=(1, 2), keepdim=True)
            # gen_pc_zero = gen_pc_zero / gen_pc_zero.std(dim=(1, 2), keepdim=True)
            

            cd_mean += CD(ref_pc_zero, gen_pc_zero).item() * B
            
            # emd_mean += EMD(ref_pc, gen_pc, transpose=False).sum().item()
            
            ref_pc_norm = normalize_to_unit_sphere(ref_pc)
            gen_pc_norm = normalize_to_unit_sphere(gen_pc)

            cd_mean_norm_sphere += CD(ref_pc_norm, gen_pc_norm).item() * B
            # emd_mean_norm_sphere += EMD(ref_pc_norm, gen_pc_norm, transpose=False).sum().item()
            
            n += B
            print(cd_mean / n)
        
    cd_mean /= n
    emd_mean /= n
    
    cd_mean_norm_sphere /= n
    emd_mean_norm_sphere /= n
    
    print(f"Steps: {steps}, CD: {cd_mean}, EMD: {emd_mean} (centered)")
    print(f"Steps: {steps}, CD: {cd_mean_norm_sphere}, EMD: {emd_mean_norm_sphere} (normalized to unit sphere)")
    
    return (cd_mean, emd_mean), (cd_mean_norm_sphere, emd_mean_norm_sphere)
    

In [12]:
import os

def save_means(means, steps):
    path = f'../metrics/{"-".join(categories)}/{scheduler}/{"distilled" if distilled else "skip"}/means/'
    os.makedirs(os.path.dirname(path), exist_ok=True)

    filename = f"{path}/means_{steps}.res"
    string = ""
    for i, ((cd, emd), (cd_norm, emd_norm)) in enumerate(means):
        string += f"Steps: {steps:4d}\n"
        string += f"CD: {cd:.8f} | CD (norm): {cd_norm:.8f}\n"
        string += f"EMD: {emd:.8f} | EMD (norm): {emd_norm:.8f}\n"
        string += "-" * 50 + "\n"
    
    best_cd = min(means, key=lambda x: x[0][0])[0][0]
    best_emd = min(means, key=lambda x: x[0][1])[0][1]
    best_cd_norm = min(means, key=lambda x: x[1][0])[1][0]
    best_emd_norm = min(means, key=lambda x: x[1][1])[1][1]
    
    string += f"Best CD: {best_cd:.8f} | Best CD (norm): {best_cd_norm:.8f}\n"
    string += f"Best EMD: {best_emd:.8f} | Best EMD (norm): {best_emd_norm:.8f}\n"
    
    with open(filename, "w") as f:
        f.write(string)
        
    print(f"Saved means to {filename}")

for steps in steps_to_run:
    means = [run_test(steps) for _ in range(1)]
    # save_means(means, steps)

  0%|          | 0/7 [00:00<?, ?it/s]

Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

0.045626357197761536


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
means