In [33]:
# table, car, chair, airplane, sofa, rifle(replace this category with something else), lamp
# Potential todo : apply certain transformations like : rotations, flip operations etc.

# Use separate position encoder for both encoder and decoders
# TODO : Multi-scale clouds

# TODO : Change the patch embedding for visible token to Mini-PointNet

# Point-MAE (WIP)

Implementation of the paper Point-MAE :- https://arxiv.org/abs/2203.06604

Note :- This code expects point clouds in a h5py with the group structure f --> data --> Category (eg. Table / Chair) etc from the **ShapeNet** data. You'll find a sample version of data in the data folder for reference. Once the data is read, one might see samples like the one below visualized using package **k3d**.

<img data-align="center" src="./Screenshot%202024-01-17%20at%2015.06.53.png" width="650" />

In [34]:
import torch
import torch.nn as nn
from torch import optim
from torchinfo import summary
from sklearn.neighbors import NearestNeighbors

from matplotlib import pyplot as plt
import numpy as np
from tqdm.notebook import tqdm
import os
from sys import platform
from torch_geometric.nn import fps, knn

from torch.profiler import profile, record_function, ProfilerActivity
from torch.utils.data import TensorDataset, DataLoader
import h5py
import k3d

In [35]:
def test_patching(original, res2):
    plot = k3d.plot()
    rand_item = np.random.choice(np.arange(len(original)))
    # Generate some random colors for visualization
    colors = [
        0xbd72b5, 0x64cb1e, 0x489d8d, 0xb65df1, 0x3f6a69, 0x0708d2, 0xad48d2, 0x2b440f, 0xbc0b26, 0xae6c42,\
    ]*10
    np.random.shuffle(colors)

    pt_samples = k3d.points(positions=original[rand_item], point_size=0.008, shader='3d',color=0x3f6bc5)
    plot += pt_samples

    for x, c in zip(res2[rand_item], colors[:len(res2)]):
        plot += k3d.points(positions=x.numpy()+0.6, point_size=0.008, shader='3d',color=c)

    plot.display()

In [36]:
NUM_PATCHES = 96
CLOUD_POINTS = 4096
FPS_RATIO = round(NUM_PATCHES/CLOUD_POINTS, 5)
K_NN_K = 32
BATCH_SIZE = 36
MASK_RATIO = 0.3
EMBEDDING_DIM = 384 # For each single patch

MLP_UNITS = EMBEDDING_DIM * 2 # for the feed-foward network in Encoder/Decoders
NUM_HEADS = 6
ENCODER_DEPTH = 12

EPOCHS = 30

In [37]:
device = "cpu"
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
print("Device in use :", device)

Device in use : mps


In [38]:
def my_chamfer_dist(pc1, pc2, batch=True):
    if platform == "darwin":
        pc1 = pc1.to("cpu")
        pc2 = pc2.to("cpu")

    assert pc1.size() == pc2.size()
    if batch:
        distances = []
        for p1, p2 in zip(pc1, pc2):
            d = torch.cdist(p1,p2).min(1).values.mean() + torch.cdist(p2,p1).min(1).values.mean()
            distances.append(d)
        return torch.stack(distances).sum()
            
    return torch.cdist(pc1,pc2).min(1).values.mean() + torch.cdist(pc2,pc1).min(1).values.mean()

In [39]:
class FPS(nn.Module):
    """
        Perform FPS sampling on batch of point cloud items.
    """
    def __init__(self, ratio):
        super().__init__()
        assert ratio > 0 and ratio <= 1
        self.ratio = ratio
    
    def forward(self, clouds):
        new_clouds = []
        # Batch data check
        assert len(clouds.size()) > 2, "make sure data is batched version"
        B = clouds.size()[0]
        N = clouds.size()[1]

        batch = (torch.arange (0, B*N) // N).contiguous()
        clouds = clouds.view(-1, 3).float().contiguous()
        fps_cloud_indices = fps(clouds, batch, ratio=self.ratio)
        return clouds[fps_cloud_indices, :].view(B, -1, 3).float()

class kNN(nn.Module):
    def __init__(self,k):
        super().__init__()
        assert k > 0
        self.k = k
    def forward(self, full_cloud, centers):
        """
            Full_cloud : Full cloud
            Centers : Coordinates of picked centers
        """
        assert len(full_cloud) == len(centers), "Both need to be of full size"
        assert len(full_cloud.size()) > 2, "make sure data is batched version"
        assert len(centers.size()) > 2, "make sure data is batched version"
        
        patch_clouds = []
        center_clouds = []

        B_x = full_cloud.size()[0]
        N_x = full_cloud.size()[1]

        B_y = centers.size()[0]
        N_y = centers.size()[1]

        batch_x = (torch.arange (0, B_x*N_x) // N_x).contiguous()
        batch_y = (torch.arange (0, B_y*N_y) // N_y).contiguous()

        full_cloud = full_cloud.view(-1, 3).float().contiguous()
        centers = centers.view(-1, 3).float().contiguous()

        assign_indices = knn(full_cloud, centers, batch_x = batch_x, batch_y=batch_y, k=self.k)
        x = full_cloud[assign_indices[1], :]
        x = x.view(B_x, -1, self.k, 3).float()
        
        # Add centers to the already selected neighbors for each patch
        c = centers.view(B_y, -1, 1, 3).float()
        x = torch.cat([x, c], 2)
        x = x - c
        return x, c

# THIS ONE WILL BE RANDOM MASKING
# Create a masking layer which performs the operations of randomized masking of certain tokens
class RandMasking(nn.Module):
    def __init__(self, ratio, batch_size, patch_points):
        super().__init__()
        assert ratio > 0 and ratio <= 1
        self.mask_ratio = ratio

        # define a special learnable shared weighted mask token
        self.mask_token = torch.nn.Parameter(torch.zeros((patch_points, 3)))
    
    def forward(self, patches):
        assert len(patches.size()) > 2, "make sure data is batched version"

        # create random indices to apply patch to
        len_ = patches.size()[1]
        indcs = np.arange(len_)
        np.random.shuffle(indcs)
        total = int(self.mask_ratio*len_)
        mask_indcs = indcs[:total]
        valid_token_indcs = indcs[total:]

        valid_tokens = patches[:, valid_token_indcs, :, :]
        masked_tokens = patches[:, mask_indcs, :, :]
        valid_mask_patches = patches[:, mask_indcs, :, :] # Return the actual mask values to serve as output values
        for b in range(masked_tokens.size()[0]):
            for i in range(len(mask_indcs)):
                masked_tokens[b,i,:,:] = self.mask_token
        return valid_tokens, masked_tokens, valid_mask_patches, mask_indcs, valid_token_indcs
        
# Create a separate masking : To remove certain object complete parts instead of random masking
# that way we'll be able to check how well symmetry is playing a role in helping the model learn inherant geometry
class SpecialMasking(nn.Module):
    pass

In [9]:
# A helper block for positional embeddings of the patches.
class PE_MLP(nn.Module):
    def __init__(self, embedding_dim, input_dim = 3):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.input_dim = input_dim
        self.mid_units = 128

        self.norm = nn.LayerNorm(normalized_shape = self.embedding_dim)
        self.DLL = nn.Sequential(
            nn.Linear(in_features = self.input_dim, out_features = self.mid_units),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(in_features = self.mid_units, out_features = self.embedding_dim),
        )
    def forward(self, x):
        return self.norm(self.DLL(x))


### Define the complete Patch Embedding Layer encompassing all operations
class Embedding(nn.Module):
    def __init__(self, fps_ratio, batch_size, k_nn_k, msk_ratio, embed_dim, num_patches):
        super().__init__()
        self.fps = FPS(fps_ratio)
        self.knn = kNN(k_nn_k-1) # Number of neighbors excluding center
        self.masking_module = RandMasking(msk_ratio, batch_size, k_nn_k)
        self.flatt_op = nn.Flatten(2,3) # Layer to flatten points in a given patch for a point cloud

        # separate MLP blocks for encoder and decoder
        self.encoder_MLP = PE_MLP(embed_dim, k_nn_k*3)
        self.decoder_MLP = PE_MLP(embed_dim, k_nn_k*3)

        # Also create another MLP block for centers
        self.centers_embedding = PE_MLP(embed_dim, 3)
    
    def forward(self, inp):
        c = self.fps(inp)
        patches, centers = self.knn(inp, c)
        valid_patches, masked_patches, valid_mask_patches, mask_indcs, val_indcs = self.masking_module(patches)
        
        # Flatten the valid patches and masked patches (x,y,z) coordinaters
        valid_patches = self.flatt_op(valid_patches)
        masked_patches = self.flatt_op(masked_patches)
        
        # masked_patches = self.flatt_op(masked_patches)
        valid_patches = self.encoder_MLP(valid_patches)
        masked_patches = self.decoder_MLP(masked_patches)

        valid_center_embeddings = self.centers_embedding(centers[:,val_indcs,:,:]).flatten(2,3)
        masked_center_embeddings = self.centers_embedding(centers[:,mask_indcs,:,:]).flatten(2,3)
        
        valid_patches = valid_patches + valid_center_embeddings
        masked_patches = masked_patches + masked_center_embeddings

        return valid_patches, masked_patches, valid_mask_patches

In [10]:
# Setup the Transformer Encoder-Decoder (with prediction head as well)

# Feed-forward block part of encoder and decoders
class FeedForward(nn.Module):
    def __init__(self, embedding_dim, units):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.units = units
        self.DLL = nn.Sequential(
            nn.Linear(in_features = self.embedding_dim, out_features = self.units),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(in_features = self.units, out_features = self.embedding_dim),
            nn.GELU()
        )
    def forward(self, x):
        return self.DLL(x)

class SelfAttentionBlock(nn.Module):
    def __init__(self, embedding_dim, num_heads):
        super().__init__()        
        self.num_heads = num_heads
        self.embedding_dim = embedding_dim

        # self.norm = nn.LayerNorm(embedding_dim)
        self.attention_heads = nn.MultiheadAttention(num_heads=num_heads, embed_dim=embedding_dim, dropout=0.2, batch_first=True)

    def forward(self,inp):
        # x = self.norm(inp)
        x = inp
        out,_ = self.attention_heads(query=x, key=x, value=x, need_weights=False)
        return out

class CrossAttentionBlock(nn.Module):
    def __init__(self, embedding_dim, num_heads):
        super().__init__()        
        self.num_heads = num_heads
        self.embedding_dim = embedding_dim
        self.attention_heads = nn.MultiheadAttention(num_heads=num_heads, embed_dim=embedding_dim, dropout=0.2, batch_first=True)

    def forward(self, k, v, q):
        # Query will come from Masked patches and key & value will come from encoder
        out,_ = self.attention_heads(query=q, key=k, value=v, need_weights=False)
        return out

class Encoder(nn.Module):
    def __init__(self, embedding_dim, mlp_units, num_heads):
        super().__init__()
        self.attention = SelfAttentionBlock(embedding_dim, num_heads=num_heads)
        self.dense = FeedForward(embedding_dim, mlp_units)
        self.norm = nn.LayerNorm(normalized_shape = embedding_dim)
    
    def forward(self,inp):
        x = self.norm(self.attention(inp)+inp)
        return self.norm(self.dense(x) + x)

class Decoder(nn.Module):
    def __init__(self, embedding_dim, mlp_units, num_heads):
        super().__init__()
        self.cross_attention = CrossAttentionBlock(embedding_dim, num_heads=num_heads)
        
        self.dense = FeedForward(embedding_dim, mlp_units)
        self.norm = nn.LayerNorm(normalized_shape = embedding_dim)
    
    def forward(self, decod_inp, encod_inp):
        x = self.cross_attention( encod_inp, encod_inp, decod_inp)
        return self.norm(self.dense(x) + x)

class PT_MAE(nn.Module):
    def __init__(self, fps_ratio, batch_size, k_nn_k, msk_ratio, embed_dim, num_patches, mlp_units, num_heads, encoder_depth, device):
        super().__init__()
        self.device = device
        self.emb = Embedding(fps_ratio, batch_size, k_nn_k, msk_ratio, embed_dim, num_patches)
        
        self.Encoders = nn.Sequential(*[Encoder(embed_dim, mlp_units, num_heads) for i in range(encoder_depth)]).to(device)

        # NOTE : For decoder block use much lower number of layers
        # self.Decoders = nn.Sequential(*[Decoder(embed_dim, mlp_units, num_heads) for i in range(encoder_depth//2)])
        
        self.D1 = Decoder(embed_dim, mlp_units, num_heads).to(device)
        self.D2 = Decoder(embed_dim, mlp_units, num_heads).to(device)
        self.D3 = Decoder(embed_dim, mlp_units, num_heads).to(device)
        self.D4 = Decoder(embed_dim, mlp_units, num_heads).to(device)

        # Create a prediction head
        self.Pred_Head = nn.Linear(embed_dim, 3*k_nn_k).to(device)
    
    def forward(self, inp):
        valid_patches, masked_patches, valid_mask_patches = self.emb(inp)
        
        valid_patches = valid_patches.to(self.device)
        masked_patches = masked_patches.to(self.device)
        valid_mask_patches = valid_mask_patches.to(self.device)
        x = self.Encoders(valid_patches)
        decoded = self.D2(self.D1(masked_patches, x), x)
        decoded = self.D3(decoded, x)
        decoded = self.D4(decoded, x)

        out = self.Pred_Head(decoded)
        return out, valid_mask_patches

In [11]:
class Model:
    def __init__(self, fps_ratio, batch_size, k_nn_k, msk_ratio, embed_dim, num_patches, mlp_units, num_heads, encoder_depth, epochs, model_path, device):
        self.batch_size = batch_size
        self.epochs = epochs
        self.device = device
        self.learning_rate = 0.0002
        self.model_path = model_path

        # Check if model_path is non empty/present
        if os.path.exists(model_path):
            print("Loading Existing model")
            self.pt_mae = torch.load(model_path)
        else:
            self.pt_mae = PT_MAE(fps_ratio, batch_size, k_nn_k, msk_ratio, embed_dim, num_patches, mlp_units, num_heads, encoder_depth, device = self.device)
        # self.pt_mae = self.pt_mae.to(self.device)

        # Setup optimizers
        self.optimizer = optim.Adam(self.pt_mae.parameters(), lr=self.learning_rate)

    def profile_model(self,):
        with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
            with h5py.File("../Data/Mini-PointNet.h5py", "r") as f:
                inputs = torch.from_numpy(f["data"]["Chair"][:40].astype(np.float32))
                with record_function("model_inference"):
                    self.pt_mae(inputs)
        print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=12))
    
    def train(self, dataset_file_path = "../Data/Mini-PointNet.h5py", epochs = None):
        # Prepare data for training
        data = []
        with h5py.File(dataset_file_path, "r") as f:
            for c in ("Table", "Chair", "Sofa:Couch", "Lamp"):
                data.append(f["data"][c][:].astype(np.float32))
        data = np.vstack(data)
        np.random.shuffle(data)

        train_dataset = TensorDataset(torch.from_numpy(data))
        train_loader = DataLoader(dataset = train_dataset, batch_size = self.batch_size, shuffle=True)

        if epochs is not None:
            epochs = epochs
        else:
            epochs = self.epochs
        
        for epoch in tqdm(range(1, epochs + 1)):
            for batch_idx, data in enumerate(train_loader):
                data = data[0]
                out, valid_mask_patches = self.pt_mae(data)
                valid_mask_patches = valid_mask_patches.flatten(2,3)
                loss = my_chamfer_dist(out, valid_mask_patches)

                if batch_idx % 600 == 0:
                    print(f"Existing Loss in Epoch : {epoch}, Batch : {batch_idx} ===>", round(loss.item(),3))
                
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            
            # Save model after every 2 epochs
            if epoch % 2 == 0:
                torch.save(self.pt_mae,self.model_path)
        # save the model at the end also
        torch.save(self.pt_mae,self.model_path)

In [12]:
m = Model(FPS_RATIO, BATCH_SIZE, K_NN_K, MASK_RATIO, EMBEDDING_DIM, NUM_PATCHES, MLP_UNITS, NUM_HEADS, ENCODER_DEPTH, EPOCHS, "./Saved_Model2", device = device)

Loading Existing model


In [13]:
m.train(epochs = 20)

  0%|          | 0/20 [00:00<?, ?it/s]

Existing Loss in Epoch : 1, Batch : 0 ===> 12.684
Existing Loss in Epoch : 2, Batch : 0 ===> 13.159
Existing Loss in Epoch : 3, Batch : 0 ===> 12.762
Existing Loss in Epoch : 4, Batch : 0 ===> 12.994
Existing Loss in Epoch : 5, Batch : 0 ===> 12.618
Existing Loss in Epoch : 6, Batch : 0 ===> 12.835
Existing Loss in Epoch : 7, Batch : 0 ===> 13.437
Existing Loss in Epoch : 8, Batch : 0 ===> 12.905
Existing Loss in Epoch : 9, Batch : 0 ===> 12.076
Existing Loss in Epoch : 10, Batch : 0 ===> 12.323
Existing Loss in Epoch : 11, Batch : 0 ===> 13.569
Existing Loss in Epoch : 12, Batch : 0 ===> 12.517
Existing Loss in Epoch : 13, Batch : 0 ===> 12.696
Existing Loss in Epoch : 14, Batch : 0 ===> 11.84
Existing Loss in Epoch : 15, Batch : 0 ===> 12.147
Existing Loss in Epoch : 16, Batch : 0 ===> 12.459
Existing Loss in Epoch : 17, Batch : 0 ===> 12.108
Existing Loss in Epoch : 18, Batch : 0 ===> 11.603
Existing Loss in Epoch : 19, Batch : 0 ===> 12.456
Existing Loss in Epoch : 20, Batch : 0 ==

In [14]:
if platform == "darwin":
    os.system("say 'Traning Epochs Completed!'")