In [2]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import trimesh
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim
import os
import random
from copy import copy
import h5py
from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


Dataset:

In [3]:
if 0:
    class ShapeNetCoreDataset(Dataset):
        def __init__(self, root_dir, num_points=1024, transform=None):
            """
            Args:
                root_dir (str): Directory with all the ShapeNetCore classes and models.
                num_points (int): Number of points to sample from each model for the point cloud.
                transform (callable, optional): Optional transform to apply to the point cloud.
            """
            self.root_dir = root_dir
            self.num_points = num_points
            self.transform = transform

            # Get list of all files in the dataset (assuming the .off format, but you may need to adjust for other formats)
            self.data_paths = []
            for dirpath, _, filenames in os.walk(self.root_dir):
                for filename in filenames:
                    if filename.endswith('.off'):
                        self.data_paths.append(os.path.join(dirpath, filename))

        def __len__(self):
            return len(self.data_paths)

        def __getitem__(self, idx):
            mesh_path = self.data_paths[idx]
            mesh = trimesh.load_mesh(mesh_path)

            # Sample points from the surface of the model
            points = mesh.sample(self.num_points)
            
            # Convert to PyTorch tensor
            points_tensor = torch.tensor(points, dtype=torch.float32)

            if self.transform:
                points_tensor = self.transform(points_tensor)

            return points_tensor

    # Usage:

    # Adjust the path to where you've stored ShapeNetCore
    root_dir = "/path/to/ShapeNetCore.v2"
    dataset = ShapeNetCoreDataset(root_dir=root_dir, num_points=1024)

    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

    # Example of iterating over the dataloader
    for batch in dataloader:
        print(batch.shape)  # Should print [32, 1024, 3] for batch_size 32 and num_points 1024

In [4]:
#borrow from https://github.com/luost26/diffusion-point-cloud/blob/1e30d48d018820fbc7c67c8b3190215bd41878e4/utils/dataset.py

synsetid_to_cate = {
    '02691156': 'airplane', '02773838': 'bag', '02801938': 'basket',
    '02808440': 'bathtub', '02818832': 'bed', '02828884': 'bench',
    '02876657': 'bottle', '02880940': 'bowl', '02924116': 'bus',
    '02933112': 'cabinet', '02747177': 'can', '02942699': 'camera',
    '02954340': 'cap', '02958343': 'car', '03001627': 'chair',
    '03046257': 'clock', '03207941': 'dishwasher', '03211117': 'monitor',
    '04379243': 'table', '04401088': 'telephone', '02946921': 'tin_can',
    '04460130': 'tower', '04468005': 'train', '03085013': 'keyboard',
    '03261776': 'earphone', '03325088': 'faucet', '03337140': 'file',
    '03467517': 'guitar', '03513137': 'helmet', '03593526': 'jar',
    '03624134': 'knife', '03636649': 'lamp', '03642806': 'laptop',
    '03691459': 'speaker', '03710193': 'mailbox', '03759954': 'microphone',
    '03761084': 'microwave', '03790512': 'motorcycle', '03797390': 'mug',
    '03928116': 'piano', '03938244': 'pillow', '03948459': 'pistol',
    '03991062': 'pot', '04004475': 'printer', '04074963': 'remote_control',
    '04090263': 'rifle', '04099429': 'rocket', '04225987': 'skateboard',
    '04256520': 'sofa', '04330267': 'stove', '04530566': 'vessel',
    '04554684': 'washer', '02992529': 'cellphone',
    '02843684': 'birdhouse', '02871439': 'bookshelf',
    # '02858304': 'boat', no boat in our dataset, merged into vessels
    # '02834778': 'bicycle', not in our taxonomy
}
cate_to_synsetid = {v: k for k, v in synsetid_to_cate.items()}


class ShapeNetCore(Dataset):

    GRAVITATIONAL_AXIS = 1
    
    def __init__(self, path, cates, split, scale_mode, transform=None):
        super().__init__()
        assert isinstance(cates, list), '`cates` must be a list of cate names.'
        assert split in ('train', 'val', 'test')
        assert scale_mode is None or scale_mode in ('global_unit', 'shape_unit', 'shape_bbox', 'shape_half', 'shape_34')
        self.path = path
        if 'all' in cates:
            cates = cate_to_synsetid.keys()
        self.cate_synsetids = [cate_to_synsetid[s] for s in cates]
        self.cate_synsetids.sort()
        self.split = split
        self.scale_mode = scale_mode
        self.transform = transform

        self.pointclouds = []
        self.stats = None

        self.get_statistics()
        self.load()

    def get_statistics(self):

        basename = os.path.basename(self.path)
        dsetname = basename[:basename.rfind('.')]
        stats_dir = os.path.join(os.path.dirname(self.path), dsetname + '_stats')
        os.makedirs(stats_dir, exist_ok=True)

        if len(self.cate_synsetids) == len(cate_to_synsetid):
            stats_save_path = os.path.join(stats_dir, 'stats_all.pt')
        else:
            stats_save_path = os.path.join(stats_dir, 'stats_' + '_'.join(self.cate_synsetids) + '.pt')
        if os.path.exists(stats_save_path):
            self.stats = torch.load(stats_save_path)
            return self.stats

        with h5py.File(self.path, 'r') as f:
            pointclouds = []
            for synsetid in self.cate_synsetids:
                for split in ('train', 'val', 'test'):
                    pointclouds.append(torch.from_numpy(f[synsetid][split][...]))

        all_points = torch.cat(pointclouds, dim=0) # (B, N, 3)
        B, N, _ = all_points.size()
        mean = all_points.view(B*N, -1).mean(dim=0) # (1, 3)
        std = all_points.view(-1).std(dim=0)        # (1, )

        self.stats = {'mean': mean, 'std': std}
        torch.save(self.stats, stats_save_path)
        return self.stats

    def load(self):

        def _enumerate_pointclouds(f):
            for synsetid in self.cate_synsetids:
                cate_name = synsetid_to_cate[synsetid]
                for j, pc in enumerate(f[synsetid][self.split]):
                    yield torch.from_numpy(pc), j, cate_name
        
        with h5py.File(self.path, mode='r') as f:
            for pc, pc_id, cate_name in _enumerate_pointclouds(f):

                if self.scale_mode == 'global_unit':
                    shift = pc.mean(dim=0).reshape(1, 3)
                    scale = self.stats['std'].reshape(1, 1)
                elif self.scale_mode == 'shape_unit':
                    shift = pc.mean(dim=0).reshape(1, 3)
                    scale = pc.flatten().std().reshape(1, 1)
                elif self.scale_mode == 'shape_half':
                    shift = pc.mean(dim=0).reshape(1, 3)
                    scale = pc.flatten().std().reshape(1, 1) / (0.5)
                elif self.scale_mode == 'shape_34':
                    shift = pc.mean(dim=0).reshape(1, 3)
                    scale = pc.flatten().std().reshape(1, 1) / (0.75)
                elif self.scale_mode == 'shape_bbox':
                    pc_max, _ = pc.max(dim=0, keepdim=True) # (1, 3)
                    pc_min, _ = pc.min(dim=0, keepdim=True) # (1, 3)
                    shift = ((pc_min + pc_max) / 2).view(1, 3)
                    scale = (pc_max - pc_min).max().reshape(1, 1) / 2
                else:
                    shift = torch.zeros([1, 3])
                    scale = torch.ones([1, 1])

                pc = (pc - shift) / scale

                self.pointclouds.append({
                    'pointcloud': pc,
                    'cate': cate_name,
                    'id': pc_id,
                    'shift': shift,
                    'scale': scale
                })

        # Deterministically shuffle the dataset
        self.pointclouds.sort(key=lambda data: data['id'], reverse=False)
        random.Random(2020).shuffle(self.pointclouds)

    def __len__(self):
        return len(self.pointclouds)

    def __getitem__(self, idx):
        data = {k:v.clone() if isinstance(v, torch.Tensor) else copy(v) for k, v in self.pointclouds[idx].items()}
        if self.transform is not None:
            data = self.transform(data)
        return data

In [5]:
#https://github.com/luost26/diffusion-point-cloud/blob/1e30d48d018820fbc7c67c8b3190215bd41878e4/utils/data.py
def get_train_val_test_datasets(dataset, train_ratio, val_ratio):
    assert (train_ratio + val_ratio) <= 1
    train_size = int(len(dataset) * train_ratio)
    val_size = int(len(dataset) * val_ratio)
    test_size = len(dataset) - train_size - val_size
    
    train_set, val_set, test_set = random_split(dataset, [train_size, val_size, test_size])
    return train_set, val_set, test_set


def get_train_val_test_loaders(dataset, train_ratio, val_ratio, train_batch_size, val_test_batch_size, num_workers):
    train_set, val_set, test_set = get_train_val_test_datasets(dataset, train_ratio, val_ratio)

    train_loader = DataLoader(train_set, train_batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_set, val_test_batch_size, shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_set, val_test_batch_size, shuffle=False, num_workers=num_workers)
    
    return train_loader, val_loader, test_loader


def get_data_iterator(iterable):
    """Allows training with DataLoaders in a single infinite loop:
        for i, data in enumerate(inf_generator(train_loader)):
    """
    iterator = iterable.__iter__()
    while True:
        try:
            yield iterator.__next__()
        except StopIteration:
            iterator = iterable.__iter__()

Encoder:

In [6]:
import torch.nn as nn
import torch.nn.functional as F

class TransformNet(nn.Module):
    """
    A mini-PointNet as T-Net in the main PointNet paper.
    Used for input transform (3x3 matrix) or feature transform (64x64 matrix).
    """
    def __init__(self, k=3):
        super(TransformNet, self).__init__()
        
        self.conv1 = nn.Conv1d(k, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 1024, 1)
        
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k*k)

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)
        
        # Initialize the weight for the last layer (fc3) to be identity
        self.fc3.weight.data = torch.eye(k).view(-1, k).repeat(1, k)
        self.fc3.bias.data.zero_()

    def forward(self, x):
        # x: [B, k, N]
        B = x.size(0)
        k = x.size(1)

        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))

        x = torch.max(x, 2, keepdim=True)[0]  # Global max pooling
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        
        # The transform matrix
        transform = self.fc3(x)
        transform = transform.view(B, k, k)

        return transform

class PointNetEncoder(nn.Module):
    def __init__(self, num_points=1024, global_feat=True):
        super(PointNetEncoder, self).__init__()
        
        self.global_feat = global_feat
        
        # Input transform (3x3 matrix for 3D point cloud)
        self.input_transform = TransformNet(k=3)
        
        # Feature transform (64x64 matrix)
        self.feature_transform = TransformNet(k=64)
        
        self.conv1 = nn.Conv1d(3, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 1024, 1)
        
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)

    def forward(self, x):
        # x: [B, N, 3]
        B, N, _ = x.size()
        
        # Transform input
        input_transform = self.input_transform(x.permute(0, 2, 1))
        x = torch.bmm(x, input_transform)
        
        # First layers
        x = F.relu(self.bn1(self.conv1(x.permute(0, 2, 1))))
        
        # Transform features
        feature_transform = self.feature_transform(x)
        x = torch.bmm(x.permute(0, 1, 2), feature_transform).permute(0, 2, 1)
        
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))
        
        # Global feature (global max pooling)
        global_feature = torch.max(x, 2, keepdim=True)[0]
        global_feature = global_feature.view(B, -1)
        
        if self.global_feat:
            return global_feature
        else:
            return global_feature, x

# Usage:

# Instantiate the encoder
encoder = PointNetEncoder(num_points=1024, global_feat=True)

# Let's say we have a batch of point clouds with size [32, 1024, 3]
# x = torch.randn(32, 1024, 3)
# global_feature = encoder(x)

# The output `global_feature` is the encoded feature of the entire point cloud.

Decoder:

In [7]:
class PointNetDecoder(nn.Module):
    def __init__(self, num_points=1024, latent_size=256):
        super(PointNetDecoder, self).__init__()
        
        # Decoder
        self.fc1 = nn.Linear(latent_size, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, 1024)
        self.fc4 = nn.Linear(1024, num_points * 3)  # times 3 for the 3D coordinates (x, y, z) of each point
        
        self.bn1 = nn.BatchNorm1d(256)
        self.bn2 = nn.BatchNorm1d(512)
        self.bn3 = nn.BatchNorm1d(1024)

    def forward(self, x):
        x = F.relu(self.bn1(self.fc1(x)))
        x = F.relu(self.bn2(self.fc2(x)))
        x = F.relu(self.bn3(self.fc3(x)))
        
        # Get the points (no activation in the last layer, since points can have any real-value coordinates)
        x = self.fc4(x)
        
        # Reshape output to [batch_size, num_points, 3]
        return x.view(x.size(0), -1, 3)

# Usage:

# Instantiate the decoder
latent_size = 1024  # This should match the output size of your encoder
decoder = PointNetDecoder(num_points=1024, latent_size=latent_size)

# Assuming `z` is the encoded latent representation from the encoder
# z = torch.randn(32, latent_size)
# decoded_points = decoder(z)


VAE

In [8]:
class PointCloudVAE(nn.Module):
    def __init__(self, num_points=1024, latent_size=256):
        super(PointCloudVAE, self).__init__()
        
        self.encoder = PointNetEncoder(num_points=num_points, global_feat=True)
        self.decoder = PointNetDecoder(num_points=num_points, latent_size=latent_size)
        
        # Define layers for producing mu and logvar from the encoded representation
        self.fc_mu = nn.Linear(1024, latent_size)
        self.fc_logvar = nn.Linear(1024, latent_size)
        
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        # Encode input point cloud
        encoded = self.encoder(x)
        
        # Get mu and logvar
        mu = self.fc_mu(encoded)
        logvar = self.fc_logvar(encoded)
        
        # Reparameterization trick
        z = self.reparameterize(mu, logvar)
        
        # Decode z to get the reconstructed point cloud
        reconstructed = self.decoder(z)
        
        return reconstructed, mu, logvar

# Loss function for VAE
def vae_loss(reconstructed, original, mu, logvar, beta=1.0):
    # Reconstruction loss: Assuming Mean Squared Error (MSE) for simplicity
    # Depending on the use case, the Chamfer distance or other suitable metrics might be used
    recon_loss = F.mse_loss(reconstructed, original, reduction='sum')
    
    # KL divergence loss
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    
    # Combine both losses
    return recon_loss + beta * kl_div

# Usage:

# Create a VAE
vae = PointCloudVAE(num_points=1024, latent_size=256)

# Assume `point_cloud` is your input batch of point clouds of size [32, 1024, 3]
# point_cloud = torch.randn(32, 1024, 3)

# Forward pass
# reconstructed, mu, logvar = vae(point_cloud)

# Compute loss
# loss = vae_loss(reconstructed, point_cloud, mu, logvar)


Train

In [9]:
def train_vae(model, train_loader, val_loader, optimizer, epochs=100, device='cuda'):
    """
    Trains the VAE model.
    """
    model.to(device)
    
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        
        # Training
        for batch in tqdm(train_loader):
            point_clouds = batch['pointcloud'].to(device)
            
            optimizer.zero_grad()
            reconstructed, mu, logvar = model(point_clouds)
            
            loss = vae_loss(reconstructed, point_clouds, mu, logvar)
            loss.backward()
            
            train_loss += loss.item()
            optimizer.step()
            
        avg_train_loss = train_loss / len(train_loader.dataset)
        
        # Validation
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                point_clouds = batch['pointcloud'].to(device)
                reconstructed, mu, logvar = model(point_clouds)
                loss = vae_loss(reconstructed, point_clouds, mu, logvar)
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader.dataset)
        
        print(f"Epoch: {epoch+1}, Avg Train Loss: {avg_train_loss:.4f}, Avg Val Loss: {avg_val_loss:.4f}")

# Usage
path_to_shapenet = "data/shapenet.hdf5"
categories = ['airplane', 'car']  # Use the categories you're interested in
split = 'train'
scale_mode = 'global_unit'  # or whichever scale mode you want to use

dataset = ShapeNetCore(path_to_shapenet, categories, split, scale_mode)
train_loader, val_loader, _ = get_train_val_test_loaders(dataset, train_ratio=0.8, val_ratio=0.1, train_batch_size=32, val_test_batch_size=32, num_workers=4)

vae = PointCloudVAE()
optimizer = torch.optim.Adam(vae.parameters(), lr=0.001)

train_vae(vae, train_loader, val_loader, optimizer)


  0%|          | 0/161 [00:01<?, ?it/s]


RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x256 and 9x3)

Visualization

In [None]:
def visualize_point_cloud(point_cloud):
    """
    Visualizes a single 3D point cloud using matplotlib.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(point_cloud[:, 0], point_cloud[:, 1], point_cloud[:, 2], marker='o')
    plt.show()

def validate_vae(model, dataloader, device='cuda'):
    """
    Validates the PointCloudVAE model and provides visualization tools.
    
    Args:
        model: The PointCloudVAE model instance.
        dataloader: A PyTorch DataLoader for validation data.
        device: Device to run on ('cuda' for GPU, 'cpu' for CPU).
    """
    model.eval()  # Set the model to evaluation mode
    model.to(device)
    
    total_loss = 0
    with torch.no_grad():
        for point_clouds in dataloader:
            point_clouds = point_clouds.to(device)
            
            reconstructed, mu, logvar = model(point_clouds)
            loss = vae_loss(reconstructed, point_clouds, mu, logvar)
            total_loss += loss.item()
            
            # Visualization for the first sample in the batch
            original = point_clouds[0].cpu().numpy()
            recon = reconstructed[0].cpu().numpy()
            
            print("Original Point Cloud:")
            visualize_point_cloud(original)
            print("Reconstructed Point Cloud:")
            visualize_point_cloud(recon)
            
    avg_loss = total_loss / len(dataloader.dataset)
    print(f"Avg Validation Loss: {avg_loss:.4f}")

# Usage:
# Assuming you have a dataloader `val_loader` for validation data
# validate_vae(vae, val_loader, device='cuda')
