In [6]:
#import pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from point_transformer_pytorch import PointTransformerLayer

In [10]:
class PointTransformerBlock(nn.Module):
    def __init__(self, in_channels, out_channels, pos_mlp_hidden_dim=32, attn_mlp_hidden_mult=4, num_neighbors=16):
        super(PointTransformerBlock, self).__init__()
        self.fc1 = nn.Linear(in_channels, in_channels)
        self.point_transformer_layer = PointTransformerLayer(
            dim=in_channels,
            pos_mlp_hidden_dim=pos_mlp_hidden_dim,
            attn_mlp_hidden_mult=attn_mlp_hidden_mult,
            num_neighbors=num_neighbors 
        )
        self.fc2 = nn.Linear(in_channels, out_channels)

    def forward(self, x, p):
        ###### Is all this reshaping required ??? ######
        ###### What do I do with the position ######
        # x shape: (batch_size, seq_len, num_points, in_channels)
        batch_size, seq_len, num_points, in_channels = x.shape
        residual = x
        x = self.fc1(x)
        mask = torch.ones(batch_size * seq_len, num_points).bool().to(x.device)
        p = p.view(batch_size * seq_len, num_points, 3)
        x = x.view(batch_size * seq_len, num_points, in_channels)
        x = self.point_transformer_layer(x, p, mask=mask)
        y = self.fc2(x.view(batch_size, seq_len, num_points, -1) + residual)
        return y, p
    

# create random variable
N = 512 # number of points
feature_size = 13 
batch_size = 32
seq_len = 1
x = torch.randn(batch_size, seq_len, N, feature_size)  # (batch_size, num_points, in_channels)
p = torch.randn(batch_size, seq_len, N, 3)  # (batch_size * seq_len, num_points, 3)
model = PointTransformerBlock(in_channels=feature_size, out_channels=128)
y, p_out = model(x, p)
print(y.shape)  # should be (batch_size, seq_len, N, out_channels)
print(p_out.shape)  # should be (batch_size * seq_len, N, 3)

torch.Size([16, 1, 512, 128])
torch.Size([16, 512, 3])


In [15]:
# input (I, 13)
# input neighbours (J, 13)
class LrgNet(nn.Module):
    def __init__(self, batch_size=32,
                 seq_len=1, # Number of region growing steps unrolled
                 num_inlier_points=512,
                 num_neighbour_points=512,
                 feature_dim=10):
        super(LrgNet, self).__init__()
        self.B1 = [128, 128]
        self.B2 = [128, 256, 512, 1024]
        self.B3 = [512, 256, 128]
        
        # Do I include this in the PointTransformer blocks?
        # Initial up-dimension layers for inlier and neighbour points (to be updated)
        self.fc1 = nn.Linear(feature_dim, self.B1[0])
        self.neigbour_fc1 = nn.Linear(feature_dim, self.B1[0])

        # Point Transformer Layers for inlier set
        for i in range(len(self.B1)-1):
            inlier_layer =  PointTransformerBlock(in_channels=self.B1[i], out_channels=self.B1[i+1])
            neighbour_layer = PointTransformerBlock(in_channels=self.B1[i],out_channels=self.B1[i+1])
            ######### Need a diagram drawn with dimensions #########
            # inlier and linear have different input shapes
            setattr(self, 'B1_inlier_' + str(i), inlier_layer)
            setattr(self, 'B1_neighbour_' + str(i), neighbour_layer)
        
        # Point Transformer Layers for outlier set
        for i in range(len(self.B2)-1):
            inlier_layer =  PointTransformerBlock(in_channels=self.B2[i], out_channels=self.B2[i+1])
            neighbour_layer = PointTransformerBlock(in_channels=self.B2[i],out_channels=self.B2[i+1])
            setattr(self, 'B2_inlier_' + str(i), inlier_layer)
            setattr(self, 'B2_neighbour_' + str(i), neighbour_layer)

        # Global Average Pooling
        ###### Does this work as intended ######
        # https://docs.pytorch.org/docs/stable/generated/torch.nn.AdaptiveMaxPool2d.html
        self.max_pool = torch.nn.AdaptiveMaxPool2d((1, 1))
        
        # B3
        for i in range(len(self.B3)-1):
            inlier_layer =  PointTransformerBlock(in_channels=self.B3[i], out_channels=self.B3[i+1], pos_mlp_hidden_dim=64)
            neighbour_layer = PointTransformerBlock(in_channels=self.B3[i],out_channels=self.B3[i+1], pos_mlp_hidden_dim=64)
            setattr(self, 'B3_inlier_' + str(i), inlier_layer)
            setattr(self, 'B3_neighbour_' + str(i), neighbour_layer)
        
        # Final classification layer
        self.remove_mask = nn.Linear(self.B3[-1], 2)  
        self.add_mask = nn.Linear(self.B3[-1], 2)  


    def forward(self, inlier_points, neighbour_points):
        p = inlier_points[-3:]
        p_neigh = neighbour_points[-3:]
        x = inlier_points[:-3] 
        x_neigh = neighbour_points[:-3]
        
        x = self.fc1(x)
        x_neigh = self.neigbour_fc1(x_neigh)

        for i in range(len(self.B1)-1):
            inlier_layer = getattr(self, 'B1_inlier_' + str(i))
            neighbour_layer = getattr(self, 'B1_neighbour_' + str(i))
            x, p = inlier_layer(x, p)
            x_neigh, p_neigh = neighbour_layer(x_neigh, p_neigh)

        residual = x
        residual_neigh = x_neigh

        for i in range(len(self.B2)-1):
            inlier_layer = getattr(self, 'B2_inlier_' + str(i))
            neighbour_layer = getattr(self, 'B2_neighbour_' + str(i))
            x, p = inlier_layer(x, p)
            x_neigh, p_neigh  = neighbour_layer(x_neigh, p_neigh)

        return x, x_neigh, p, p_neigh

        # Global Max Pooling
        x = x.permute(0, 3, 1, 2)
        x = self.max_pool(x)
        x = x.view(x.size(0), -1)
        x_neigh = x_neigh.permute(0, 3, 1, 2)
        x_neigh = self.max_pool(x_neigh)
        x_neigh = x_neigh.view(x_neigh.size(0), -1)

        # B3
        for i in range(len(self.B3)-1):
            inlier_layer = getattr(self, 'B3_inlier_' + str(i))
            neighbour_layer = getattr(self, 'B3_neighbour_' + str(i))
            x = inlier_layer(x, p)
            x_neigh = neighbour_layer(x_neigh, p_neigh)

        # Final classification
        remove_mask_logits = self.remove_mask(x)
        add_mask_logits = self.add_mask(x_neigh)

        return remove_mask_logits, add_mask_logits
    
model = LrgNet(batch_size=batch_size,
                seq_len=seq_len,
                num_inlier_points=N,
                num_neighbour_points=N,
                feature_dim=feature_size)
x_neigh = torch.randn(batch_size, seq_len, N, feature_size)  # (batch_size, num_points, in_channels)
x, x_neigh, p, p_neigh = model(x, x_neigh)

TypeError: forward() takes 2 positional arguments but 3 were given

In [4]:
import torch
from point_transformer_pytorch import PointTransformerLayer

attn = PointTransformerLayer(
    dim = 128,
    pos_mlp_hidden_dim = 64,
    attn_mlp_hidden_mult = 4,
    num_neighbors = 16          # only the 16 nearest neighbors would be attended to for each point
)

feats = torch.randn(1, 2048, 128)
pos = torch.randn(1, 2048, 3)
mask = torch.ones(1, 2048).bool()

out = attn(feats, pos, mask = mask) # (1, 16, 128)

# print output shape
out.shape

torch.Size([1, 2048, 128])