In [1]:
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

In [39]:
class PointNetConv(nn.Module):
    def __init__(self, in_features, out_features, kernel_size=(3,3), padding=(1,1), stride=(1,1)):
        super(PointNetConv, self).__init__()
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.out_features = out_features
        self.in_features = in_features
        self.weights = nn.Parameter(data=torch.Tensor(1, 1, out_features, 4), requires_grad=True)
        self.layer = nn.Conv1d(in_features, out_features, 1)
        
    def forward(self, x=None, X=None, A=None, E=None, M=None):
        unfolded_x = F.unfold(x, self.kernel_size, self.padding, self.stride)
        unfolded_X = F.unfold(X, self.kernel_size, self.padding, self.stride)
        unfolded_A = F.unfold(A, self.kernel_size, self.padding, self.stride)
        unfolded_E = F.unfold(E, self.kernel_size, self.padding, self.stride)
        unfolded_M = F.unfold(M, self.kernel_size, self.padding, self.stride)

        print(unfolded_x.shape)
        X_c = X.view(X.shape[0], X.shape[1], -1)
        A_c = A.view(A.shape[0], A.shape[1], -1)
        E_c = E.view(E.shape[0], E.shape[1], -1)
        M_c = M.view(M.shape[0], M.shape[1], -1)

        X_c = torch.repeat_interleave(X_c, self.kernel_size[0]*self.kernel_size[1], 1)
        A_c = torch.repeat_interleave(A_c, self.kernel_size[0]*self.kernel_size[1], 1)
        E_c = torch.repeat_interleave(E_c, self.kernel_size[0]*self.kernel_size[1], 1)
        M_c = torch.repeat_interleave(M_c, self.kernel_size[0]*self.kernel_size[1], 1)
        print(X_c.shape)
        
        x_ = unfolded_x.unsqueeze(3)           #.permute((0,2,1)).unsqueeze(3)
        X_ = (unfolded_X - X_c).unsqueeze(-1)  #.permute((0,2,1)).unsqueeze(3)
        A_ = (unfolded_A - A_c).unsqueeze(-1)  #.permute((0,2,1)).unsqueeze(3)
        E_ = (unfolded_E - E_c).unsqueeze(-1)  #.permute((0,2,1)).unsqueeze(3)

        ####
        first_feature = x_ * torch.cos(A_) * torch.cos(E_) - X_c.unsqueeze(-1)
        second_feature = x_ * torch.cos(A_) * torch.sin(E_)
        third_feature = x_ * torch.sin(A_)

        M_valid = M_c * unfolded_M
        M_valid = M_valid.unsqueeze(1)
        print(M_valid.shape)
        ####

        print(x_.shape,first_feature.shape)
        total_features = torch.cat([x_, 
                                    first_feature, 
                                    second_feature, 
                                    third_feature], axis=3)
                        
        total_features = total_features.permute((0, 3, 1, 2)) 
        total_features_flattened = total_features.reshape((total_features.shape[0], total_features.shape[1], -1))
        out = self.layer(total_features_flattened)
        out = out.reshape((total_features.shape[0],
                           out.shape[1],
                           total_features.shape[2],
                           total_features.shape[3]))
        print(out.shape)
        out = out * M_valid
        out = torch.max(out, dim=2)[0].reshape(out.shape[0],
                                               out.shape[1],
                                               x.shape[2],
                                               x.shape[3])
        #out = out * M_valid
        print(out.shape)
        
        # weights = torch.repeat_interleave(self.weights, self.kernel_size[0]*self.kernel_size[1], 2)
        # total_features = torch.repeat_interleave(total_features, self.out_features // self.in_features, 2)
        # print(total_features.shape, weights.shape)
        # total_features = weights * total_features
        # print(total_features.shape)
        # total_features = torch.sum(total_features, axis=3)
        # print(total_features.shape)
        # total_features = total_features.view(total_features.shape[0], total_features.shape[1], -1, 9)
        # print(total_features.shape)
        # total_features, _ = torch.max(total_features, dim=3, keepdim=True)
        # total_features = total_features.squeeze(3)
        # print(total_features.shape)
        # total_features = total_features.permute((0,2,1))
        # total_features = total_features.view(total_features.shape[0],total_features.shape[1],x.shape[2],-1)
        # print(total_features.shape)

In [40]:
point_net = PointNetConv(in_features=4, out_features=64)
x = torch.randint(0,10,(8,1,16,1024), dtype=torch.float32)
X = torch.randint(0,10,(8,1,16,1024), dtype=torch.float32)
A = torch.randint(0,10,(8,1,16,1024), dtype=torch.float32)
E = torch.randint(0,10,(8,1,16,1024), dtype=torch.float32)
M = torch.randint(0,10,(8,1,16,1024), dtype=torch.float32)
point_net.forward(x, X, A, E, M)

torch.Size([8, 9, 16384])
torch.Size([8, 9, 16384])
torch.Size([8, 1, 9, 16384])
torch.Size([8, 9, 16384, 1]) torch.Size([8, 9, 16384, 1])
torch.Size([8, 64, 9, 16384])
torch.Size([8, 64, 16, 1024])


In [None]:
point_net = PointNetConv(in_features=64, out_features=128)
x = torch.randint(0,10,(2,64,16,128),dtype=torch.float32)
X = torch.randint(0,10,(2,64,16,128),dtype=torch.float32)
A = torch.randint(0,10,(2,64,16,128),dtype=torch.float32)
E = torch.randint(0,10,(2,64,16,128),dtype=torch.float32)
point_net.forward(x, X, A, E)