In [1]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import trimesh
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim
import os
import random
from copy import copy
import h5py
from tqdm.auto import tqdm

In [2]:
from models.pvcnn2 import Voxelization
from third_party.pvcnn.functional.devoxelization import trilinear_devoxelize

Using /mnt/fsx-home/xingchenliu/.cache/torch_extensions/py38_cu116 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /mnt/fsx-home/xingchenliu/.cache/torch_extensions/py38_cu116/_pvcnn_backend/build.ninja...
Building extension module _pvcnn_backend...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module _pvcnn_backend...


Dataset:

In [3]:
# Borrowed from https://github.com/luost26/diffusion-point-cloud/blob/1e30d48d018820fbc7c67c8b3190215bd41878e4/utils/dataset.py

# Lookup table from synset id (the 8-digit string used by ShapeNetCore below as
# the top-level key when indexing the HDF5 file) to a human-readable category name.
synsetid_to_cate = {
    '02691156': 'airplane', '02773838': 'bag', '02801938': 'basket',
    '02808440': 'bathtub', '02818832': 'bed', '02828884': 'bench',
    '02876657': 'bottle', '02880940': 'bowl', '02924116': 'bus',
    '02933112': 'cabinet', '02747177': 'can', '02942699': 'camera',
    '02954340': 'cap', '02958343': 'car', '03001627': 'chair',
    '03046257': 'clock', '03207941': 'dishwasher', '03211117': 'monitor',
    '04379243': 'table', '04401088': 'telephone', '02946921': 'tin_can',
    '04460130': 'tower', '04468005': 'train', '03085013': 'keyboard',
    '03261776': 'earphone', '03325088': 'faucet', '03337140': 'file',
    '03467517': 'guitar', '03513137': 'helmet', '03593526': 'jar',
    '03624134': 'knife', '03636649': 'lamp', '03642806': 'laptop',
    '03691459': 'speaker', '03710193': 'mailbox', '03759954': 'microphone',
    '03761084': 'microwave', '03790512': 'motorcycle', '03797390': 'mug',
    '03928116': 'piano', '03938244': 'pillow', '03948459': 'pistol',
    '03991062': 'pot', '04004475': 'printer', '04074963': 'remote_control',
    '04090263': 'rifle', '04099429': 'rocket', '04225987': 'skateboard',
    '04256520': 'sofa', '04330267': 'stove', '04530566': 'vessel',
    '04554684': 'washer', '02992529': 'cellphone',
    '02843684': 'birdhouse', '02871439': 'bookshelf',
    # '02858304': 'boat', no boat in our dataset, merged into vessels
    # '02834778': 'bicycle', not in our taxonomy
}
# Reverse lookup: category name -> synset id.
cate_to_synsetid = {v: k for k, v in synsetid_to_cate.items()}


class ShapeNetCore(Dataset):
    """In-memory dataset of point clouds read from a single HDF5 file.

    The file is indexed as ``f[synset_id][split]``. Every point cloud of the
    requested categories and split is loaded, normalised according to
    ``scale_mode`` and kept in ``self.pointclouds`` as a dict with the keys
    ``'pointcloud'``, ``'cate'``, ``'id'``, ``'shift'`` and ``'scale'``.
    """

    GRAVITATIONAL_AXIS = 1

    def __init__(self, path, cates, split, scale_mode, transform=None):
        """Load statistics and all point clouds for the given categories.

        Args:
            path: Path to the HDF5 file.
            cates: List of category names (keys of ``cate_to_synsetid``); a
                list containing ``'all'`` selects every known category.
            split: One of ``'train'``, ``'val'``, ``'test'``.
            scale_mode: ``None`` or one of ``'global_unit'``, ``'shape_unit'``,
                ``'shape_bbox'``, ``'shape_half'``, ``'shape_34'``.
            transform: Optional callable applied to each sample dict in
                ``__getitem__``.
        """
        super().__init__()
        assert isinstance(cates, list), '`cates` must be a list of cate names.'
        assert split in ('train', 'val', 'test')
        assert scale_mode is None or scale_mode in ('global_unit', 'shape_unit', 'shape_bbox', 'shape_half', 'shape_34')
        self.path = path
        if 'all' in cates:
            cates = cate_to_synsetid.keys()
        self.cate_synsetids = [cate_to_synsetid[s] for s in cates]
        self.cate_synsetids.sort()
        self.split = split
        self.scale_mode = scale_mode
        self.transform = transform

        self.pointclouds = []
        self.stats = None

        self.get_statistics()
        self.load()

    def get_statistics(self):
        """Return (and cache on disk) the mean/std of the selected categories.

        Statistics are computed over the train, val and test splits together
        and stored in ``<dataset name>_stats/`` next to the HDF5 file; an
        existing stats file is loaded instead of being recomputed.

        Returns:
            dict with ``'mean'`` (per-coordinate mean) and ``'std'`` (one
            standard deviation over all coordinates).
        """
        # splitext keeps an extension-less filename intact; slicing with
        # rfind('.') would chop its last character because rfind returns -1.
        dsetname = os.path.splitext(os.path.basename(self.path))[0]
        stats_dir = os.path.join(os.path.dirname(self.path), dsetname + '_stats')
        os.makedirs(stats_dir, exist_ok=True)

        if len(self.cate_synsetids) == len(cate_to_synsetid):
            stats_save_path = os.path.join(stats_dir, 'stats_all.pt')
        else:
            stats_save_path = os.path.join(stats_dir, 'stats_' + '_'.join(self.cate_synsetids) + '.pt')
        if os.path.exists(stats_save_path):
            self.stats = torch.load(stats_save_path)
            return self.stats

        with h5py.File(self.path, 'r') as f:
            pointclouds = []
            for synsetid in self.cate_synsetids:
                for split in ('train', 'val', 'test'):
                    pointclouds.append(torch.from_numpy(f[synsetid][split][...]))

        all_points = torch.cat(pointclouds, dim=0)  # (B, N, 3)
        B, N, _ = all_points.size()
        mean = all_points.view(B*N, -1).mean(dim=0)  # one value per coordinate
        std = all_points.view(-1).std(dim=0)         # single scalar

        self.stats = {'mean': mean, 'std': std}
        torch.save(self.stats, stats_save_path)
        return self.stats

    def load(self):
        """Read the requested split into memory, normalise it and shuffle it."""

        def _enumerate_pointclouds(f):
            # Yields (tensor, index within its category, category name).
            for synsetid in self.cate_synsetids:
                cate_name = synsetid_to_cate[synsetid]
                for j, pc in enumerate(f[synsetid][self.split]):
                    yield torch.from_numpy(pc), j, cate_name

        with h5py.File(self.path, mode='r') as f:
            for pc, pc_id, cate_name in _enumerate_pointclouds(f):

                if self.scale_mode == 'global_unit':
                    # Per-shape centring, but the dataset-wide std as scale.
                    shift = pc.mean(dim=0).reshape(1, 3)
                    scale = self.stats['std'].reshape(1, 1)
                elif self.scale_mode == 'shape_unit':
                    shift = pc.mean(dim=0).reshape(1, 3)
                    scale = pc.flatten().std().reshape(1, 1)
                elif self.scale_mode == 'shape_half':
                    shift = pc.mean(dim=0).reshape(1, 3)
                    scale = pc.flatten().std().reshape(1, 1) / (0.5)
                elif self.scale_mode == 'shape_34':
                    shift = pc.mean(dim=0).reshape(1, 3)
                    scale = pc.flatten().std().reshape(1, 1) / (0.75)
                elif self.scale_mode == 'shape_bbox':
                    # Centre on the bounding box and scale by half its longest side.
                    pc_max, _ = pc.max(dim=0, keepdim=True)  # (1, 3)
                    pc_min, _ = pc.min(dim=0, keepdim=True)  # (1, 3)
                    shift = ((pc_min + pc_max) / 2).view(1, 3)
                    scale = (pc_max - pc_min).max().reshape(1, 1) / 2
                else:
                    # scale_mode is None: leave the points untouched.
                    shift = torch.zeros([1, 3])
                    scale = torch.ones([1, 1])

                pc = (pc - shift) / scale

                self.pointclouds.append({
                    'pointcloud': pc,
                    'cate': cate_name,
                    'id': pc_id,
                    'shift': shift,
                    'scale': scale
                })

        # Deterministically shuffle the dataset: sort by id first so the
        # fixed-seed shuffle always starts from the same order.
        self.pointclouds.sort(key=lambda data: data['id'], reverse=False)
        random.Random(2020).shuffle(self.pointclouds)

    def __len__(self):
        """Number of loaded point clouds."""
        return len(self.pointclouds)

    def __getitem__(self, idx):
        """Return a copy of sample ``idx`` (tensors cloned), after ``transform`` if set."""
        data = {k:v.clone() if isinstance(v, torch.Tensor) else copy(v) for k, v in self.pointclouds[idx].items()}
        if self.transform is not None:
            data = self.transform(data)
        return data

In [4]:
# Borrowed from https://github.com/luost26/diffusion-point-cloud/blob/1e30d48d018820fbc7c67c8b3190215bd41878e4/utils/data.py
def get_train_val_test_datasets(dataset, train_ratio, val_ratio):
    """Randomly partition ``dataset`` into train, validation and test subsets.

    The train and validation sizes are ``len(dataset)`` times the respective
    ratio, truncated to an integer; the test subset gets every remaining sample.

    Returns:
        Tuple ``(train_set, val_set, test_set)``.
    """
    assert (train_ratio + val_ratio) <= 1
    total = len(dataset)
    n_train = int(total * train_ratio)
    n_val = int(total * val_ratio)
    lengths = [n_train, n_val, total - n_train - n_val]

    subsets = random_split(dataset, lengths)
    return tuple(subsets)


def get_train_val_test_loaders(dataset, train_ratio, val_ratio, train_batch_size, val_test_batch_size, num_workers):
    """Split ``dataset`` and wrap each part in a ``DataLoader``.

    Only the training loader shuffles; the validation and test loaders share
    ``val_test_batch_size`` and keep their order.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    train_set, val_set, test_set = get_train_val_test_datasets(dataset, train_ratio, val_ratio)

    def _eval_loader(subset):
        # Validation and test loaders are configured identically.
        return DataLoader(subset, val_test_batch_size, shuffle=False, num_workers=num_workers)

    train_loader = DataLoader(train_set, train_batch_size, shuffle=True, num_workers=num_workers)
    return train_loader, _eval_loader(val_set), _eval_loader(test_set)


def get_data_iterator(iterable):
    """Yield items from ``iterable`` endlessly, restarting it when exhausted.

    Allows training with DataLoaders in a single infinite loop:
        for i, data in enumerate(get_data_iterator(train_loader)):

    If a complete pass over ``iterable`` yields nothing (an empty loader, or a
    one-shot iterator that is already consumed), the generator stops instead
    of restarting forever without ever producing a value.
    """
    while True:
        yielded_any = False
        for item in iterable:
            yielded_any = True
            yield item
        if not yielded_any:
            # Restarting cannot produce data; stop rather than busy-loop.
            return

In [5]:
# Data-splitting and iteration helpers (NOTE: these repeat the definitions from the previous cell verbatim)
def get_train_val_test_datasets(dataset, train_ratio, val_ratio):
    """Randomly split ``dataset`` three ways by the given ratios.

    ``train_ratio`` and ``val_ratio`` are fractions of ``len(dataset)``
    (truncated to whole samples); the test subset is whatever is left over.

    Returns:
        Tuple ``(train_set, val_set, test_set)``.
    """
    assert (train_ratio + val_ratio) <= 1
    n_samples = len(dataset)
    train_size, val_size = (int(n_samples * ratio) for ratio in (train_ratio, val_ratio))
    test_size = n_samples - train_size - val_size

    train_set, val_set, test_set = random_split(dataset, [train_size, val_size, test_size])
    return (train_set, val_set, test_set)


def get_train_val_test_loaders(dataset, train_ratio, val_ratio, train_batch_size, val_test_batch_size, num_workers):
    """Build train / validation / test ``DataLoader``s from one dataset.

    The dataset is first split with ``get_train_val_test_datasets``. The
    training loader shuffles; the other two do not.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    subsets = get_train_val_test_datasets(dataset, train_ratio, val_ratio)
    train_set, val_set, test_set = subsets

    # (subset, batch size, shuffle) in the order the loaders are returned.
    specs = (
        (train_set, train_batch_size, True),
        (val_set, val_test_batch_size, False),
        (test_set, val_test_batch_size, False),
    )
    train_loader, val_loader, test_loader = (
        DataLoader(subset, batch_size, shuffle=do_shuffle, num_workers=num_workers)
        for subset, batch_size, do_shuffle in specs
    )
    return train_loader, val_loader, test_loader


def get_data_iterator(iterable):
    """Cycle over ``iterable`` forever, re-iterating it after each full pass.

    Allows training with DataLoaders in a single infinite loop:
        for i, data in enumerate(get_data_iterator(train_loader)):

    A pass that produces no items ends the generator: an empty iterable would
    otherwise be restarted indefinitely without ever yielding.
    """
    while True:
        items_this_pass = 0
        for element in iterable:
            items_this_pass += 1
            yield element
        if items_this_pass == 0:
            # Nothing to cycle over; finish instead of spinning.
            return

# If you want the infinite loop, use get_data_iterator instead:
# train_iter = get_data_iterator(train_loader)

Visualize the dataset

In [6]:

def visualize_pointcloud(points, title=""):
    """Show one point cloud as a 3D scatter plot.

    ``points`` is indexed as ``points[:, k]`` for k = 0, 1, 2 to obtain the
    X, Y and Z coordinates.
    """
    figure = plt.figure(figsize=(8, 8))
    axes = figure.add_subplot(111, projection='3d')

    xs, ys, zs = points[:, 0], points[:, 1], points[:, 2]
    axes.scatter(xs, ys, zs, s=5)

    # Label the three axes, then the figure as a whole.
    for set_label, text in ((axes.set_xlabel, 'X'), (axes.set_ylabel, 'Y'), (axes.set_zlabel, 'Z')):
        set_label(text)
    axes.set_title(title)

    plt.show()

def visualize_dataset(dataset, num_samples=5):
    """Visualize random samples from the dataset.

    Args:
        dataset: Indexable collection whose items are dicts with a
            ``'pointcloud'`` tensor and a ``'cate'`` title.
        num_samples: Maximum number of distinct samples to plot. When the
            dataset holds fewer items, all of them are shown.
    """
    # Sampling without replacement raises ValueError if more items are
    # requested than exist, so cap the request at the dataset size.
    count = min(num_samples, len(dataset))
    if count <= 0:
        return
    indices = np.random.choice(len(dataset), count, replace=False)
    for idx in indices:
        sample = dataset[idx]
        points = sample['pointcloud'].numpy()
        title = sample['cate']
        visualize_pointcloud(points, title=title)


Encoder:

In [7]:
class VariationalVoxelNetEncoder(nn.Module):
    """Variational encoder: voxelize -> 3D convs -> devoxelize -> set abstraction -> (mu, logvar).

    ``forward`` returns a latent sample ``z`` and stores the KL term of that
    call on ``self.kl``.

    NOTE(review): ``create_sa_module`` is not defined or imported in any cell
    visible in this notebook; constructing this class presumably fails unless
    it is provided elsewhere -- confirm.
    NOTE(review): ``num_points`` is accepted but never used in this class.
    """
    def __init__(self, num_points=2048, latent_size=256, resolution=32, normalize=True, eps=1e-8):
        super(VariationalVoxelNetEncoder, self).__init__()

        # Voxelization layer
        self.voxelizer = Voxelization(resolution, normalize, eps)

        # Voxel convolution layers
        self.voxel_conv1 = nn.Conv3d(1, 64, 3, stride=2, padding=1)  # assuming single channel voxel input
        self.voxel_conv2 = nn.Conv3d(64, 128, 3, stride=2, padding=1)
        self.voxel_conv3 = nn.Conv3d(128, 256, 3, stride=2, padding=1)

        # Set abstraction layers (from PVCNN's sa_blocks)
        # NOTE(review): the meaning of each tuple entry is defined by create_sa_module,
        # which is not visible here -- confirm against its implementation.
        self.sa_blocks = [
            ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
            ((64, 3, 16), (256, 0.2, 32, (64, 128))),
            ((128, 3, 8), (64, 0.4, 32, (128, 256))),
            (None, (16, 0.8, 32, (128, 128, 128)))
        ]
        self.sa_layers = nn.ModuleList([create_sa_module(config) for config in self.sa_blocks])

        # Fully connected layers to produce mu, logvar
        # NOTE(review): both heads expect 256 input features; forward() feeds them the
        # flattened output of the last SA layer -- confirm that the sizes agree.
        self.fc_mu = nn.Linear(256, latent_size)  # adjusted based on last SA layer
        self.fc_logvar = nn.Linear(256, latent_size)

    def reparameterize(self, mu, logvar):
        # Sample z = mu + eps * std with eps ~ N(0, I), so the sample stays differentiable
        # with respect to mu and logvar.
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        # Assuming x has shape [B, D, N]
        # The last 3 entries along dim 1 are taken as coordinates, the rest as features.
        features = x[:, :-3, :]
        coords = x[:, -3:, :]

        # Voxelization
        voxelized_features, norm_coords = self.voxelizer(features, coords)

        # Voxel convolution layers
        # NOTE(review): each conv uses stride=2, so the grid leaving voxel_conv3 no longer
        # has the voxelizer's spatial size, yet self.voxelizer.r is passed to the
        # devoxelizer below -- confirm this is intended.
        voxel_out = F.relu(self.voxel_conv1(voxelized_features))
        voxel_out = F.relu(self.voxel_conv2(voxel_out))
        voxel_out = F.relu(self.voxel_conv3(voxel_out))

        # Reshape voxel_out to [B, C, R*R*R] for devoxelization
        voxel_out = voxel_out.view(voxel_out.size(0), voxel_out.size(1), -1)

        # Devoxelization
        point_cloud = trilinear_devoxelize(voxel_out, norm_coords, self.voxelizer.r)

        # Set abstraction layers
        for sa in self.sa_layers:
            point_cloud = sa(point_cloud)

        # Flatten and fully connected to produce mu, logvar
        point_cloud = point_cloud.view(point_cloud.size(0), -1)
        mu = self.fc_mu(point_cloud)
        logvar = self.fc_logvar(point_cloud)

        # Reparameterization trick to get z
        z = self.reparameterize(mu, logvar)

        # Store KL divergence for loss computation
        # torch.sum has no dim argument here, so the term is summed over the whole batch
        # as well as over the latent dimensions.
        self.kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        return z



Decoder:

In [8]:
class VoxelNetDecoder(nn.Module):
    """Decoder: latent -> linear expand -> feature propagation -> voxelize -> 3D deconvs -> devoxelize.

    NOTE(review): ``create_fp_module`` is not defined or imported in any cell
    visible in this notebook; constructing this class presumably fails unless
    it is provided elsewhere -- confirm.
    NOTE(review): ``num_points`` is accepted but never used in this class.
    """
    def __init__(self, num_points=2048, latent_size=256, resolution=32, normalize=True, eps=1e-8):
        super(VoxelNetDecoder, self).__init__()

        # Fully connected layer to expand the latent space
        self.fc_expand = nn.Linear(latent_size, 256)  # expand latent size to a suitable dimension

        # Feature Propagation layers
        # NOTE(review): the meaning of each tuple entry is defined by create_fp_module,
        # which is not visible here -- confirm against its implementation.
        self.fp_blocks = [
            ((128, 128), (128, 3, 8)),
            ((128, 128), (128, 3, 8)),
            ((128, 128), (128, 2, 16)),
            ((128, 128, 64), (64, 2, 32))
        ]
        self.fp_layers = nn.ModuleList([create_fp_module(config) for config in self.fp_blocks])  # assumes a helper function to create fp layers

        # Voxel deconvolution layers
        self.voxel_deconv3 = nn.ConvTranspose3d(256, 128, 3, stride=2, padding=1)
        self.voxel_deconv2 = nn.ConvTranspose3d(128, 64, 3, stride=2, padding=1)
        self.voxel_deconv1 = nn.ConvTranspose3d(64, 1, 3, stride=2, padding=1)  # single channel voxel output

        # Voxelization layer
        self.voxelizer = Voxelization(resolution, normalize, eps)

    def forward(self, z):
        # Expand the latent representation
        expanded = F.relu(self.fc_expand(z))

        # Reshape it to be processed by FP layers
        # The 256 outputs of fc_expand are viewed as 128 x 2 per batch element.
        point_cloud = expanded.view(expanded.size(0), 128, 2)  # assuming 128 features and 2 points

        # Feature propagation layers
        for fp in self.fp_layers:
            point_cloud = fp(point_cloud)

        # Assuming point_cloud is of shape [B, D, N]
        # The last 3 entries along dim 1 are taken as coordinates, the rest as features.
        features = point_cloud[:, :-3, :]
        coords = point_cloud[:, -3:, :]

        # Voxelization
        voxelized_features, norm_coords = self.voxelizer(features, coords)

        # Voxel deconvolution layers
        # NOTE(review): voxel_deconv3 expects 256 input channels -- confirm that the
        # voxelized features actually have that many channels.
        voxel_out = F.relu(self.voxel_deconv3(voxelized_features))
        voxel_out = F.relu(self.voxel_deconv2(voxel_out))
        voxel_out = self.voxel_deconv1(voxel_out)  # may need an activation function

        # Reshape voxel_out to [B, C, R*R*R] for devoxelization
        voxel_out = voxel_out.view(voxel_out.size(0), voxel_out.size(1), -1)

        # Devoxelization
        # NOTE(review): the stride-2 transposed convs change the grid size, yet
        # self.voxelizer.r is passed as the resolution -- confirm this is intended.
        reconstructed_pointcloud = trilinear_devoxelize(voxel_out, norm_coords, self.voxelizer.r)

        return reconstructed_pointcloud

In [9]:
def chamfer_distance(p1, p2):
    """
    Compute the (squared) Chamfer Distance between two batches of point clouds.

    Args:
    - p1 (torch.Tensor): A tensor of shape (B, N, D): B point clouds of N points, each of dimension D.
    - p2 (torch.Tensor): A tensor of shape (B, M, D). M may differ from N.

    Returns:
    - distance (torch.Tensor): A tensor of shape (B,). For each pair it is half the sum of
      (a) the squared distance from every point of p1 to its nearest point in p2 and
      (b) the squared distance from every point of p2 to its nearest point in p1.
    """

    # Pairwise squared distances via ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b
    sq_norm_p1 = torch.sum(p1**2, dim=2).unsqueeze(2)  # (B, N, 1)
    sq_norm_p2 = torch.sum(p2**2, dim=2).unsqueeze(1)  # (B, 1, M)
    cross = torch.matmul(p1, p2.permute(0, 2, 1))      # (B, N, M)
    dists = sq_norm_p1 + sq_norm_p2 - 2 * cross        # (B, N, M)

    # The expanded form can come out slightly negative through rounding
    # (e.g. for coincident points); a squared distance is never negative.
    dists = dists.clamp(min=0)

    # For each point in p1, find the closest distance in p2
    min_dists_p1 = torch.min(dists, dim=2)[0]  # Shape (B, N)

    # For each point in p2, find the closest distance in p1
    min_dists_p2 = torch.min(dists, dim=1)[0]  # Shape (B, M)

    # Combine the two directions: half the sum of both totals
    distance = (torch.sum(min_dists_p1, dim=1) + torch.sum(min_dists_p2, dim=1)) / 2

    return distance

VAE

In [10]:
class PointCloudVAE(nn.Module):
    """VAE that chains ``VariationalVoxelNetEncoder`` and ``VoxelNetDecoder``.

    Both sub-modules are built with the same ``num_points`` and ``latent_size``.
    """

    def __init__(self, num_points=2048, latent_size=256):
        super().__init__()

        shared_kwargs = dict(num_points=num_points, latent_size=latent_size)
        self.encoder = VariationalVoxelNetEncoder(**shared_kwargs)
        self.decoder = VoxelNetDecoder(**shared_kwargs)

    def forward(self, x):
        """Encode ``x`` and decode the latent sample.

        Returns:
            Tuple of the decoder output and the KL term the encoder stored on
            itself during this call.
        """
        latent = self.encoder(x)
        reconstruction = self.decoder(latent)
        return reconstruction, self.encoder.kl


In [11]:
# The KL term is computed inside the encoder and passed in here, because mu and logvar are not returned separately.
def vae_loss(reconstructed, original, kl_div, beta=0.0):
    """Combine the reconstruction loss with a weighted KL term.

    The reconstruction loss is the batch mean of ``chamfer_distance``; the
    result is that mean plus ``beta * kl_div``.
    """
    reconstruction_term = chamfer_distance(reconstructed, original).mean()
    kl_term = beta * kl_div
    return reconstruction_term + kl_term

In [12]:
# Create a VAE
# NOTE(review): the output recorded for this cell is a TypeError, so the model
# was not built successfully in the saved run.
vae = PointCloudVAE(num_points=2048, latent_size=256)

TypeError: apply() takes no keyword arguments

Train

In [None]:
def train_vae(model, train_loader, val_loader, optimizer, epochs=100, device='cuda', print_interval=100):
    """
    Run the train / validate loop for the VAE and checkpoint the best weights.

    After every epoch the averaged losses are printed, and whenever the
    average validation loss improves the model's state dict is written to
    'best_vae_weights.pth' in the current working directory.

    Args:
        model: The VAE model to train; calling it returns (reconstruction, kl).
        train_loader: DataLoader for training data (batches are dicts with 'pointcloud').
        val_loader: DataLoader for validation data.
        optimizer: Optimizer for training.
        epochs: Number of epochs for training.
        device: Device to move the model to. 'cuda' or 'cpu'.
        print_interval: Number of batches after which training loss is printed.
    """
    model.to(device)
    best_val_loss = float('inf')

    def _batch_losses(batch):
        # Forward pass for one batch -> (combined, kl, recon) loss tensors.
        point_clouds = batch['pointcloud'].to(device)
        reconstructed, kl_div = model(point_clouds)
        recon_loss = chamfer_distance(reconstructed, point_clouds).mean()
        combined_loss = vae_loss(reconstructed, point_clouds, kl_div)
        return combined_loss, kl_div, recon_loss

    def _averages(sums, loader):
        # Per-batch averages of the accumulated (combined, kl, recon) sums.
        n_batches = len(loader)
        return tuple(value / n_batches for value in sums)

    for epoch in range(epochs):
        # ---- training pass ----
        model.train()
        train_sums = [0, 0, 0]
        for i, batch in enumerate(train_loader):
            optimizer.zero_grad()
            combined_loss, kl_div, recon_loss = _batch_losses(batch)

            combined_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            for slot, term in enumerate((combined_loss, kl_div, recon_loss)):
                train_sums[slot] += term.item()

            if (i+1) % print_interval == 0:
                print(f"Epoch {epoch+1}, Batch {i+1} - Combined Loss: {combined_loss.item():.4f}, KL Div: {kl_div.item():.4f}, Recon Loss: {recon_loss.item():.4f}")

        avg_train_loss, avg_train_kl_div, avg_train_recon_loss = _averages(train_sums, train_loader)

        # ---- validation pass (no gradients) ----
        model.eval()
        val_sums = [0, 0, 0]
        with torch.no_grad():
            for batch in val_loader:
                losses = _batch_losses(batch)
                for slot, term in enumerate(losses):
                    val_sums[slot] += term.item()

        avg_val_loss, avg_val_kl_div, avg_val_recon_loss = _averages(val_sums, val_loader)

        # ---- epoch report ----
        print(f"\nEpoch: {epoch+1} Summary:")
        print(f"Train Avg Combined Loss: {avg_train_loss:.4f}, Train Avg KL Div: {avg_train_kl_div:.4f}, Train Avg Recon Loss: {avg_train_recon_loss:.4f}")
        print(f"Val Avg Combined Loss: {avg_val_loss:.4f}, Val Avg KL Div: {avg_val_kl_div:.4f}, Val Avg Recon Loss: {avg_val_recon_loss:.4f}\n")

        # Keep only the weights with the lowest validation loss seen so far.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), 'best_vae_weights.pth')


In [None]:
# Build a model with the default sizes, an Adam optimizer over its parameters,
# and select the GPU when one is available (CPU otherwise).
vae = PointCloudVAE()
optimizer = torch.optim.Adam(vae.parameters(), lr=0.001, weight_decay=0)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# NOTE(review): train_loader and val_loader are not created in any cell visible
# in this notebook -- they presumably come from get_train_val_test_loaders; confirm.
train_vae(vae, train_loader, val_loader, optimizer,device=device)

Visualization

In [None]:
def validate_vae(model, val_loader, device='cuda'):
    """
    Validates the VAE model and returns the average validation loss.

    Args:
        model: The VAE; calling it returns (reconstruction, kl).
        val_loader: Loader whose batches are dicts with a 'pointcloud' tensor.
        device: Device the batches are moved to before the forward pass.

    Returns:
        The mean of the per-batch ``vae_loss`` values, i.e. the same
        "average over batches" that ``train_vae`` reports.
    """
    model.eval()
    total_val_loss = 0
    with torch.no_grad():
        for batch in val_loader:
            point_clouds = batch['pointcloud'].to(device)
            reconstructed, kl_div = model(point_clouds)
            loss = vae_loss(reconstructed, point_clouds, kl_div)
            total_val_loss += loss.item()

    # Each accumulated term is already a per-batch mean, so normalise by the
    # number of batches, not by the number of samples in the dataset.
    avg_val_loss = total_val_loss / len(val_loader)
    return avg_val_loss


In [None]:
def visualize_reconstruction(model, dataset, num_samples=5, device='cuda'):
    """
    Visualize the original and reconstructed point clouds side by side.

    Args:
        model: The VAE; calling it returns (reconstruction, kl).
        dataset: Indexable collection whose items are dicts with a
            'pointcloud' tensor.
        num_samples: Maximum number of distinct samples to plot. When the
            dataset holds fewer items, all of them are shown.
        device: Device each sample is moved to before the forward pass.
    """
    model.eval()

    # Sampling without replacement raises ValueError if more items are
    # requested than exist, so cap the request at the dataset size.
    count = min(num_samples, len(dataset))
    if count <= 0:
        return
    indices = np.random.choice(len(dataset), count, replace=False)

    for idx in indices:
        sample = dataset[idx]
        # Add a batch dimension of 1 for the model, then strip it again.
        points = sample['pointcloud'].unsqueeze(0).to(device)
        with torch.no_grad():
            reconstructed, _ = model(points)
        reconstructed = reconstructed.squeeze(0).cpu().numpy()
        points = points.squeeze(0).cpu().numpy()

        # Plotting
        fig = plt.figure(figsize=(15, 7))

        # Original point cloud
        ax1 = fig.add_subplot(121, projection='3d')
        ax1.scatter(points[:, 0], points[:, 1], points[:, 2], s=5)
        ax1.set_title("Original")

        # Reconstructed point cloud
        ax2 = fig.add_subplot(122, projection='3d')
        ax2.scatter(reconstructed[:, 0], reconstructed[:, 1], reconstructed[:, 2], s=5)
        ax2.set_title("Reconstructed")

        plt.show()


In [None]:
# NOTE(review): val_dset is not created in any cell visible in this notebook -- confirm
# where it comes from. `device` is not passed, so the function's default 'cuda' is used.
visualize_reconstruction(vae, val_dset) 

Generate new point clouds

In [None]:
def sample_from_vae(model, num_samples=5, latent_dim=256, device='cuda'):
    """
    Draw latent vectors from a standard normal and decode them.

    Returns whatever ``model.decoder`` produces for a
    ``(num_samples, latent_dim)`` batch of latents moved to ``device``.
    """
    model.eval()
    latent_shape = (num_samples, latent_dim)
    with torch.no_grad():
        # Sample first, then move to the target device.
        z_sampled = torch.randn(*latent_shape).to(device)
        return model.decoder(z_sampled)


In [None]:
def visualize_generated_point_clouds(model, num_samples=5, latent_dim=256, device='cuda'):
    """
    Sample point clouds from the VAE and show each one in its own 3D scatter plot.
    """
    generated_point_clouds = sample_from_vae(model, num_samples, latent_dim, device)

    for i in range(num_samples):
        cloud = generated_point_clouds[i].cpu().numpy()
        xs, ys, zs = cloud[:, 0], cloud[:, 1], cloud[:, 2]

        # One figure per generated sample.
        figure = plt.figure(figsize=(7, 7))
        axes = figure.add_subplot(111, projection='3d')
        axes.scatter(xs, ys, zs, s=5)
        axes.set_title(f"Generated Point Cloud {i+1}")

        plt.show()


In [None]:
# `device` is not passed, so the function's default 'cuda' is used regardless of
# the `device` variable selected earlier.
visualize_generated_point_clouds(vae)
