In [1]:
%load_ext autoreload
%autoreload 2

import torch
from models.encoders.pointnet import PointNetEncoder, ResNetEncoder

In [6]:
# Smoke-test PointNetEncoder on a random batch of shape (32, 2048, 3).
# NOTE(review): the constructor arguments (100, 3) look like (latent dim, point dim),
# consistent with the two (32, 100) shapes in the recorded output — confirm against the class.
encoder = PointNetEncoder(100, 3)

a = torch.rand(32, 2048, 3)
# The encoder returns two tensors, unpacked here as `mu` and `var`.
mu, var = encoder(a)

# Only these two prints come from this cell; any additional shapes in the recorded
# output are presumably debug prints inside the encoder — TODO confirm.
print(mu.shape)
print(var.shape)

torch.Size([32, 512, 2048])
torch.Size([32, 512, 1])
torch.Size([32, 512])
torch.Size([32, 100])
torch.Size([32, 100])


In [6]:
# Same smoke test for ResNetEncoder: feed a random (32, 2048, 3) batch and
# print the shapes of the two returned tensors.
encoder_resnet = ResNetEncoder(100, 3)
# NOTE: rebinds the kernel-global names `a`, `mu` and `var` used by the PointNet cell.
a = torch.rand(32, 2048, 3)
mu, var = encoder_resnet(a)

print(mu.shape)
print(var.shape)

torch.Size([32, 100])
torch.Size([32, 100])


In [8]:
# One 128-dim feature vector per batch element: shape (B=32, 1, 128).
input_tensor = torch.randn(32, 1, 128)

# Tile the singleton second dimension 2048 times so each of the 2048 slots holds
# a copy of the same feature vector: (32, 1, 128) -> (32, 2048, 128).
repeated_tensor = input_tensor.repeat(1, 2048, 1)
print(repeated_tensor.shape)

# No further reshape is needed: `repeat` already yields the (B, 2048, 128) layout.


torch.Size([32, 2048, 128])


In [7]:
import os
import math
import argparse
import torch
import torch.utils.tensorboard
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
from tqdm.auto import tqdm

from utils.dataset import *
from utils.misc import *
from utils.data import *
from models.vae_gaussian import *
from models.vae_flow import *
from models.flow import add_spectral_norm, spectral_norm_power_iteration
from evaluation import *

class Params:
    """Notebook stand-in for the command-line arguments of the training script.

    Every option is stored as a plain instance attribute, in the same order as
    before, so code that reads `args.<name>` or `vars(args)` keeps working.

    Args:
        **overrides: Optional `name=value` pairs replacing the defaults below,
            e.g. `Params(device='cpu', dataset_path='data/shapenet.hdf5')`.
            This avoids editing the class to run on another machine.

    Raises:
        TypeError: If an override names an option that does not exist
            (guards against silently ignored typos).
    """

    def __init__(self, **overrides):
        self.beta_1 = 0.0001
        self.beta_T = 0.02
        self.categories = ['airplane']
        # Machine-specific default; override `dataset_path` on other hosts.
        self.dataset_path = '/data/dongyin/diffusion-point-cloud/data/shapenet.hdf5'
        # Machine-specific default; override `device` when this GPU is unavailable.
        self.device = 'cuda:2'
        self.diffusion_layer_type = 'squash'
        self.encoder = 'resnet'
        self.end_lr = 0.0001
        self.flexibility = 0.0
        self.kl_weight = 0.001
        self.latent_dim = 256
        self.latent_flow_depth = 14
        self.latent_flow_hidden_dim = 256
        self.log_root = './logs_gen'
        self.logging = True
        self.lr = 0.002
        self.max_grad_norm = 10
        self.max_iters = float('inf')
        self.model = 'flow'
        self.num_samples = 4
        self.num_steps = 100
        self.residual = True
        self.sample_num_points = 2048
        self.scale_mode = 'shape_unit'
        self.sched_end_epoch = 400000
        self.sched_mode = 'linear'
        self.sched_start_epoch = 200000
        self.seed = 2020
        self.spectral_norm = False
        self.tag = None
        self.test_freq = 30000
        self.test_size = 100
        self.train_batch_size = 128
        self.truncate_std = 2.0
        self.val_batch_size = 64
        self.val_freq = 1000
        self.weight_decay = 0

        # Apply overrides last, and only for options that already exist.
        unknown = sorted(set(overrides) - set(vars(self)))
        if unknown:
            raise TypeError(f"Unknown Params option(s): {', '.join(unknown)}")
        vars(self).update(overrides)

# Build the run configuration and call the project seeding helper before any
# dataset or model is constructed.
args = Params()
seed_all(args.seed)

# Train / validation splits of the selected categories, with the same scaling mode.
train_dset = ShapeNetCore(
    path=args.dataset_path,
    cates=args.categories,
    split='train',
    scale_mode=args.scale_mode,
)
val_dset = ShapeNetCore(
    path=args.dataset_path,
    cates=args.categories,
    split='val',
    scale_mode=args.scale_mode,
)
# The loader is wrapped by the project helper `get_data_iterator`.
train_iter = get_data_iterator(DataLoader(
    train_dset,
    batch_size=args.train_batch_size,
    num_workers=0,
))

# Select the model family. Fail loudly on an unsupported name instead of
# leaving `model` undefined and crashing later with a NameError.
if args.model == 'gaussian':
    model = GaussianVAE(args).to(args.device)
elif args.model == 'flow':
    model = FlowVAE(args).to(args.device)
else:
    raise ValueError(
        f"Unsupported args.model {args.model!r}; expected 'gaussian' or 'flow'."
    )

if args.spectral_norm:
    add_spectral_norm(model)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=args.lr,
    weight_decay=args.weight_decay,
)
# Learning-rate schedule from `args.lr` to `args.end_lr` between the two
# configured schedule boundaries.
scheduler = get_linear_scheduler(
    optimizer,
    start_epoch=args.sched_start_epoch,
    end_epoch=args.sched_end_epoch,
    start_lr=args.lr,
    end_lr=args.end_lr,
)

def validate_inspect():
    """Sample `args.num_samples` point clouds from the model for inspection.

    Reads the notebook-level `args` and `model`.

    Returns:
        The value returned by `model.sample` for freshly drawn standard-normal
        latents of shape (args.num_samples, args.latent_dim). Previously the
        samples were computed and then discarded.
    """
    # Pure sampling: no autograd graph is needed.
    with torch.no_grad():
        z = torch.randn([args.num_samples, args.latent_dim]).to(args.device)
        # `truncate_std` is deliberately not forwarded (it was disabled in the original call).
        return model.sample(z, args.sample_num_points, flexibility=args.flexibility)

def test():
    """Evaluate generated point clouds against validation references.

    Reads the notebook-level `args`, `model` and `val_dset`.

    Collects up to `args.test_size` reference clouds from `val_dset`, samples
    `args.test_size` clouds from the model, and compares the two sets.

    Returns:
        dict: The entries of `compute_all_metrics` converted to Python scalars,
        plus a 'jsd' entry from `jsd_between_point_cloud_sets`. Previously this
        dict was built and then dropped, so the evaluation had no visible result.
    """
    # Reference set: the first `test_size` validation items (fewer if the split is smaller).
    ref_pcs = []
    for i, data in enumerate(val_dset):
        if i >= args.test_size:
            break
        ref_pcs.append(data['pointcloud'].unsqueeze(0))
    ref_pcs = torch.cat(ref_pcs, dim=0)

    # Generated set: sample whole batches, then trim the overshoot of the last batch.
    num_batches = math.ceil(args.test_size / args.val_batch_size)
    gen_pcs = []
    with torch.no_grad():
        for _ in tqdm(range(num_batches), 'Generate'):
            z = torch.randn([args.val_batch_size, args.latent_dim]).to(args.device)
            x = model.sample(z, args.sample_num_points, flexibility=args.flexibility)
            gen_pcs.append(x.detach().cpu())
    gen_pcs = torch.cat(gen_pcs, dim=0)[:args.test_size]

    # Denormalize point clouds, all shapes have zero mean.
    # [WARNING]: Do NOT denormalize!
    # ref_pcs *= val_dset.stats['std']
    # gen_pcs *= val_dset.stats['std']

    with torch.no_grad():
        results = compute_all_metrics(gen_pcs.to(args.device), ref_pcs.to(args.device), args.val_batch_size)
        results = {k: v.item() for k, v in results.items()}
        results['jsd'] = jsd_between_point_cloud_sets(gen_pcs.cpu().numpy(), ref_pcs.cpu().numpy())

    return results

using resnet encoder.


In [52]:
# A single pair of random tensors, each of shape (1, 2048, 3), used as toy
# inputs for the EMD approximation defined in the next cell.
sample = torch.randn(1, 2048, 3)
ref = torch.randn(1, 2048, 3)

In [None]:
import torch
from itertools import permutations

import torch
import torch.nn.functional as F

def M(C, u, v, reg):
    r"""Modified cost for logarithmic (log-domain) Sinkhorn updates.

    Computes $M_{ij} = (-c_{ij} + u_i + v_j) / \epsilon$.

    Args:
        C: Cost matrix, indexed [i, j].
        u: 1-D vector added along the rows of `C` (one entry per row).
        v: 1-D vector added along the columns of `C` (one entry per column).
        reg: The regularization strength $\epsilon$.

    Returns:
        Tensor with the same shape as `C`.
    """
    # Raw docstring: `\epsilon` would otherwise be an invalid escape sequence.
    return (-C + u.unsqueeze(1) + v.unsqueeze(0)) / reg

def lse(A):
    """Row-wise log-sum-exp with a 1e-6 floor inside the logarithm.

    Mathematically equal to `log(sum_j exp(A[i, j]) + 1e-6)`, but evaluated
    without ever materialising `exp(A)`, so large entries no longer overflow
    to `inf`.

    Args:
        A: 2-D tensor.

    Returns:
        Tensor of shape (A.shape[0], 1).
    """
    row_lse = torch.logsumexp(A, dim=1, keepdim=True)
    # The floor keeps the result finite (log(1e-6)) when a whole row underflows.
    floor = torch.full_like(row_lse, 1e-6).log()
    return torch.logaddexp(row_lse, floor)


def sinkhorn(dist_mat, reg, num_iters=100):
    """
    Run the log-domain Sinkhorn algorithm with uniform marginals.

    Args:
    - dist_mat (torch.Tensor): The cost matrix. shape=(n_rows, n_cols);
      square in the original use, rectangular matrices are also accepted.
    - reg (float): The regularization term.
    - num_iters (int): Maximum number of iterations to run the algorithm.

    Returns:
    - P (torch.Tensor): The (approximate) optimal transport matrix, same shape
      as dist_mat.
    """
    n_rows, n_cols = dist_mat.shape

    # Create the marginals on dist_mat's device (and floating dtype) so CUDA
    # inputs work; they used to be hard-wired CPU float32 tensors.
    dtype = dist_mat.dtype if dist_mat.is_floating_point() else torch.get_default_dtype()
    mu = torch.full((n_rows,), 1. / n_rows, dtype=dtype, device=dist_mat.device)
    nu = torch.full((n_cols,), 1. / n_cols, dtype=dtype, device=dist_mat.device)
    log_mu, log_nu = torch.log(mu), torch.log(nu)  # loop-invariant
    thresh = 10**(-1)  # stopping criterion on the summed absolute change of u

    u, v = torch.zeros_like(mu), torch.zeros_like(nu)

    for _ in range(num_iters):
        u_prev = u  # kept to measure the size of this update
        u = reg * (log_mu - lse(M(dist_mat, u, v, reg)).squeeze(1)) + u
        v = reg * (log_nu - lse(M(dist_mat, u, v, reg).t()).squeeze(1)) + v

        # Stop early once the row potential has (loosely) converged.
        if (u - u_prev).abs().sum() < thresh:
            break

    return torch.exp(M(dist_mat, u, v, reg))

def emd_approx(sample_batch, reference_batch, reg=0.1):
    """
    Approximate the EMD between paired point sets and average over the batch.

    Each pair is solved separately with `sinkhorn` on the squared Euclidean
    cost matrix, so the value is an entropy-regularized approximation.

    Args:
    - sample_batch (torch.Tensor): The sample sets. shape=(batch_size, n_samples, n_dims)
    - reference_batch (torch.Tensor): The reference sets, paired with
      sample_batch by batch index. shape=(batch_size, n_samples, n_dims)
    - reg (float): The regularization term.

    Returns:
    - emd (torch.Tensor): Scalar tensor, the mean transport cost over the batch.
    """
    batch_size = sample_batch.shape[0]
    # Accumulate on the inputs' device rather than always on the CPU.
    emd = torch.zeros((batch_size,), device=sample_batch.device)
    for b in range(batch_size):
        # Pairwise squared distances: (n, 1, d) - (1, n, d) -> (n, n)
        diff = sample_batch[b].unsqueeze(1) - reference_batch[b].unsqueeze(0)
        dist_mat = torch.sum(diff ** 2, dim=2)

        # Transport plan for this pair
        P = sinkhorn(dist_mat, reg)

        # Transport cost under that plan
        emd[b] = torch.sum(P * dist_mat)

    return torch.mean(emd)

def emd_approx_batch(sample_batch, reference_batch, reg=0.1, num_iters=100):
    """
    Calculate the EMD for each pair of sample and reference examples in a batch using Sinkhorn iterations.

    The iterations run in the log domain with uniform marginals (1/n_samples on
    the sample points, 1/n_refs on the reference points) and a squared
    Euclidean cost.

    Args:
    - sample_batch (torch.Tensor): The batch of sample sets. shape=(batch_size, n_samples, n_dims)
    - reference_batch (torch.Tensor): The batch of reference sets. shape=(batch_size, n_refs, n_dims)
    - reg (float): The regularization term.
    - num_iters (int): Number of iterations to run the algorithm.

    Returns:
    - emd (torch.Tensor): The EMD for each pair of examples. shape=(batch_size,)
    """
    batch_size, n_samples, _ = sample_batch.shape
    n_refs = reference_batch.shape[1]

    # Pairwise squared distances via (a-b)^2 = a^2 + b^2 - 2ab; clamp the tiny
    # negative values that floating-point cancellation can produce.
    x_sq = torch.sum(sample_batch * sample_batch, dim=-1, keepdim=True)
    y_sq = torch.sum(reference_batch * reference_batch, dim=-1, keepdim=True)
    xy = torch.matmul(sample_batch, reference_batch.transpose(-2, -1))
    dist_mat = (x_sq + y_sq.transpose(-2, -1) - 2 * xy).clamp_min(0)

    # Work with log K instead of K = exp(-C / reg): K underflows to exact zeros
    # for distant points, which turns the scaling updates into divisions by zero.
    log_K = -dist_mat / reg
    log_mu = log_K.new_full((batch_size, n_samples), 1.0 / n_samples).log()
    log_nu = log_K.new_full((batch_size, n_refs), 1.0 / n_refs).log()

    log_u = log_mu.clone()
    log_v = torch.zeros_like(log_nu)
    for _ in range(num_iters):
        # Both scalings divide the target marginal by the current marginal:
        #   v = nu / (K^T u),  u = mu / (K v).
        # The earlier code left out the division in the v update, so the plan's
        # rows summed to 1 instead of 1/n_samples and the cost was mis-scaled.
        log_v = log_nu - torch.logsumexp(log_K + log_u.unsqueeze(-1), dim=1)
        log_u = log_mu - torch.logsumexp(log_K + log_v.unsqueeze(1), dim=2)

    # Transport plan P = diag(u) K diag(v), assembled in the log domain.
    P = torch.exp(log_u.unsqueeze(-1) + log_K + log_v.unsqueeze(1))

    # Transport cost per batch element
    emd = torch.sum(P * dist_mat, dim=[1, 2])

    return emd

# Try the per-pair approximation on the single random pair (`sample`, `ref`)
# and then on a fresh random batch of 128 pairs.
# NOTE(review): emd_approx loops over the batch, so the second call performs 128
# separate 2048x2048 Sinkhorn solves — expect it to be slow.
print(emd_approx(sample, ref))
print(emd_approx(torch.randn(128, 2048, 3), torch.randn(128, 2048, 3)))

In [55]:
# Number of items in the validation split; test() draws at most
# `args.test_size` reference clouds from it.
len(val_dset)

607

In [8]:
# Restore trained weights from a saved checkpoint and run the evaluation.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
# map_location places the tensors on the configured device instead of the device
# they were saved from, which may be busy or absent on this machine.
ckpt = torch.load(
    "./logs_gen/GEN_2023_06_04__23_33_18/ckpt_0.000000_30000.pt",
    map_location=args.device,
)
model.load_state_dict(ckpt['state_dict'])


# validate_inspect()
test()

Generate:   0%|          | 0/2 [00:00<?, ?it/s]

Pairwise EMD CD


Pairwise EMD-CD:   0%|          | 0/100 [00:00<?, ?it/s]

KeyboardInterrupt: 