In [8]:
#import pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from point_transformer_pytorch import PointTransformerLayer

In [9]:
# create random variable
N = 512 # number of points
feature_size = 13 
batch_size = 16
seq_len = 1
x = torch.randn(batch_size, seq_len, N, feature_size)  # (batch_size, num_points, in_channels)

In [15]:
# input (I, 13)
# input neighbours (J, 13)
class LrgNet(nn.Module):
    def __init__(self, batch_size=32,
                 seq_len=1, # Number of region growing steps unrolled
                 num_inlier_points=512,
                 num_neighbour_points=512,
                 feature_dim=13):
        super(LrgNet, self).__init__()
        self.B1 = [128, 128]
        self.B2 = [128, 256, 512, 1024]
        self.B3 = [512, 256, 128]
        
        # Do I include this in the PointTransformer blocks?
        # Initial up-dimension layers for inlier and neighbour points (to be updated)
        self.fc1 = nn.Linear(feature_dim, self.B1[0])
        self.neigbour_fc1 = nn.Linear(feature_dim, self.B1[0])

        # Point Transformer Layers for inlier set
        for i in range(len(self.B1)-1):
            inlier_layer = PointTransformerLayer(dim=self.B1[i], pos_mlp_hidden_dim=32, attn_mlp_hidden_mult=4, num_neighbors = 16)
            linear_layer = nn.Linear(self.B1[i], self.B1[i+1])
            neighbour_layer = PointTransformerLayer(dim=self.B1[i], pos_mlp_hidden_dim=32, attn_mlp_hidden_mult=4, num_neighbors = 16)
            linear_layer_neighbour = nn.Linear(self.B1[i], self.B1[i+1])

            ######### Need a diagram drawn with dimensions #########
            # inlier and linear have different input shapes
            setattr(self, 'B1_inlier_' + str(i), nn.Sequential(inlier_layer, linear_layer))
            setattr(self, 'B1_neighbour_' + str(i), nn.Sequential(neighbour_layer, linear_layer_neighbour))
        
        # Point Transformer Layers for outlier set
        for i in range(len(self.B2)-1):
            inlier_layer = PointTransformerLayer(dim=self.B2[i], pos_mlp_hidden_dim=64, attn_mlp_hidden_mult=4, num_neighbors = 16)
            neighbour_layer = PointTransformerLayer(dim=self.B2[i], pos_mlp_hidden_dim=64, attn_mlp_hidden_mult=4, num_neighbors = 16)
            linear_layer = nn.Linear(self.B2[i], self.B2[i+1])
            linear_layer_neighbour = nn.Linear(self.B2[i], self.B2[i+1])  
            setattr(self, 'B2_inlier_' + str(i), nn.Sequential(inlier_layer, linear_layer))
            setattr(self, 'B2_neighbour_' + str(i), nn.Sequential(neighbour_layer, linear_layer_neighbour))

        # Global Average Pooling
        ###### Does this work as intended ######
        # https://docs.pytorch.org/docs/stable/generated/torch.nn.AdaptiveMaxPool2d.html
        self.max_pool = torch.nn.AdaptiveMaxPool2d((1, 1))
        
        # B3
        for i in range(len(self.B3)-1):
            inlier_layer = PointTransformerLayer(dim=self.B3[i], pos_mlp_hidden_dim=64, attn_mlp_hidden_mult=4, num_neighbors = 16)
            neighbour_layer = PointTransformerLayer(dim=self.B3[i], pos_mlp_hidden_dim=64, attn_mlp_hidden_mult=4, num_neighbors = 16)
            linear_layer = nn.Linear(self.B3[i], self.B3[i+1])
            linear_layer_neighbour = nn.Linear(self.B3[i], self.B3[i+1])  
            setattr(self, 'B3_inlier_' + str(i), nn.Sequential(inlier_layer, linear_layer))
            setattr(self, 'B3_neighbour_' + str(i), nn.Sequential(neighbour_layer, linear_layer_neighbour))
        
        # Final classification layer
        self.remove_mask = nn.Linear(self.B3[-1], 2)  # Assuming binary classification
        self.add_mask = nn.Linear(self.B3[-1], 2)  # Assuming binary classification


    def forward(self, inlier_points, neighbour_points):
        p = inlier_points[-3:]
        p_neigh = neighbour_points[-3:]
        x = self.fc1(inlier_points[:-3])
        x_neigh = self.neigbour_fc1(neighbour_points[:-3])

        for i in range(len(self.B1)-1):
            inlier_layer = getattr(self, 'B1_inlier_' + str(i))
            neighbour_layer = getattr(self, 'B1_neighbour_' + str(i))
            x = inlier_layer(x, p)
            x_neigh = neighbour_layer(x_neigh, p_neigh)

        for i in range(len(self.B2)-1):
            inlier_layer = getattr(self, 'B2_inlier_' + str(i))
            neighbour_layer = getattr(self, 'B2_neighbour_' + str(i))
            x = inlier_layer(x, p)
            x_neigh = neighbour_layer(x_neigh, p_neigh)

        return x, x_neigh, p, p_neigh

        # Global Max Pooling
        x = x.permute(0, 3, 1, 2)
        x = self.max_pool(x)
        x = x.view(x.size(0), -1)
        x_neigh = x_neigh.permute(0, 3, 1, 2)
        x_neigh = self.max_pool(x_neigh)
        x_neigh = x_neigh.view(x_neigh.size(0), -1)

        # B3
        for i in range(len(self.B3)-1):
            inlier_layer = getattr(self, 'B3_inlier_' + str(i))
            neighbour_layer = getattr(self, 'B3_neighbour_' + str(i))
            x = inlier_layer(x, p)
            x_neigh = neighbour_layer(x_neigh, p_neigh)

        # Final classification
        remove_mask_logits = self.remove_mask(x)
        add_mask_logits = self.add_mask(x_neigh)

        return remove_mask_logits, add_mask_logits
    
model = LrgNet(batch_size=batch_size,
                seq_len=seq_len,
                num_inlier_points=N,
                num_neighbour_points=N,
                feature_dim=feature_size)
outlier_points = torch.randn(batch_size, seq_len, N, feature_size)  # (batch_size, num_points, in_channels)
x, x_neigh, p, p_neigh = model(x, outlier_points)

TypeError: forward() takes 2 positional arguments but 3 were given

In [7]:
import torch
from point_transformer_pytorch import PointTransformerLayer

attn = PointTransformerLayer(
    dim = 128,
    pos_mlp_hidden_dim = 64,
    attn_mlp_hidden_mult = 4,
)

feats = torch.randn(1, 16, 128)
pos = torch.randn(1, 16, 3)
mask = torch.ones(1, 16).bool()

out = attn(feats, pos, mask = mask) # (1, 16, 128)

# print output shape
out.shape

TypeError: __init__() got an unexpected keyword argument 'dim_out'