In [8]:
import torch

# Toy point cloud: 1 batch, 3 points, 2 coordinates per point.
a = torch.randn(1, 3, 2)

# Inserting a singleton axis after the point axis gives shape (1, 3, 1, 2);
# broadcasting it against a (1, 3, 2) yields every pairwise difference:
# diffs[b, i, j] = a[b, i] - a[b, j], shape (1, 3, 3, 2).
diffs = a[:, :, None, :] - a

print(diffs.shape, diffs)

print(a)

# Sanity check: point 0 minus point 1 should equal diffs[:, 0, 1].
print(a[:, 0] - a[:, 1])

torch.Size([1, 3, 3, 2]) tensor([[[[ 0.0000,  0.0000],
          [ 1.2818,  3.5853],
          [ 0.7203,  1.6506]],

         [[-1.2818, -3.5853],
          [ 0.0000,  0.0000],
          [-0.5616, -1.9347]],

         [[-0.7203, -1.6506],
          [ 0.5616,  1.9347],
          [ 0.0000,  0.0000]]]])
tensor([[[-0.2038,  2.7274],
         [-1.4856, -0.8579],
         [-0.9240,  1.0768]]])
tensor([[1.2818, 3.5853]])


In [28]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.layers.torch import Reduce
from pytorch3d.ops import knn_points


#ARPE: Absolute Relative Position Encoding
class ARPE(nn.Module):
    """Absolute Relative Position Encoding for point clouds.

    Every point is described by its own coordinates (absolute part)
    concatenated with its offsets to its ``k`` nearest neighbours (relative
    part). The per-neighbour features pass through a shared linear layer with
    batch norm and ELU, are max-pooled over the neighbourhood, and are finally
    projected to ``out_channels`` by a second linear layer with batch norm and
    ELU.

    Args:
        in_channels: number of channels C of each input point.
        out_channels: number of channels of the returned encoding.
        npoints: not used by the module; kept so existing callers that pass
            it keep working.
        k: number of nearest neighbours per point. Defaults to 3, the value
            that used to be hard-coded.
    """

    def __init__(self, in_channels=3, out_channels=32, npoints=1024, k=3):
        super(ARPE, self).__init__()

        # An earlier experiment scaled the neighbourhood with the cloud size,
        # int(32 * npoints / 512); pass that value as `k` to reproduce it.
        self.k = k

        self.lin1 = nn.Linear(2*in_channels, 2*in_channels)
        self.lin2 = nn.Linear(2*in_channels, out_channels)

        self.bn1 = nn.BatchNorm1d(2*in_channels)
        self.bn2 = nn.BatchNorm1d(out_channels)

        # Max over the neighbour axis, keeping it as a singleton dimension.
        self.max_pooling_layer = Reduce('bn k f -> bn 1 f', 'max')

    def forward(self, x):
        """Encode a batch of point clouds.

        Args:
            x: tensor of shape (B, N, C) with C == in_channels.

        Returns:
            Tensor of shape (B, N, out_channels).
        """
        B, N, C = x.shape  # B: batch size, N: number of points, C: channels

        # Third field of the knn_points result: the neighbours' coordinates.
        # Query and reference set are the same tensor, so each point is
        # presumably returned as its own nearest neighbour (zero offset).
        knn = knn_points(x, x, K=self.k, return_nn=True)[2]  # B, N, K, C

        centers = x.unsqueeze(2)  # B, N, 1, C
        diffs = centers - knn  # B, N, K, C

        # Absolute coordinates next to the relative offsets.
        x = torch.cat([centers.repeat(1, 1, self.k, 1), diffs], dim=-1)  # B, N, K, 2*C

        # BatchNorm1d normalises dim 1, hence the transpose round trip.
        x = self.lin1(x.view(B*N, self.k, 2*C))  # B*N, K, 2*C
        x = F.elu(self.bn1(x.transpose(1, 2)).transpose(1, 2))  # B*N, K, 2*C

        # The pooled neighbour axis is dim 1 (the old squeeze(2) hit the
        # feature axis and was a silent no-op).
        x = self.max_pooling_layer(x).squeeze(1)  # B*N, 1, 2*C -> B*N, 2*C

        x = self.lin2(x.view(B, N, 2*C))  # B, N, out_channels
        x = F.elu(self.bn2(x.transpose(1, 2)).transpose(1, 2))  # B, N, out_channels

        return x  # B, N, out_channels



# Smoke test: run ARPE with its default arguments on a random batch of
# 32 clouds x 1024 points x 3 coordinates and print the output shape.
x = torch.randn(32, 1024, 3)
# `arpe` is reused by a later cell, so it stays a notebook-level name.
arpe = ARPE()
y = arpe(x)
print(y.shape)

torch.Size([32, 1024, 32])


In [29]:
# Hand-written cloud: a single batch of 5 points with 3 coordinates each,
# small enough to inspect neighbourhoods by eye.
test1 = torch.Tensor([[[0,1,0],[1,1,0],[5,4,2],[1,-1,0],[5,4,5]]])

#print(test1.shape, test1)

# Reuses the `arpe` instance created in an earlier cell.
# NOTE(review): no visible cell calls arpe.eval(), so the BatchNorm layers
# presumably still use batch statistics here - confirm if that matters.
res1 = arpe(test1)

#print(res1.shape, res1)

In [34]:
def channel_shuffle(x, groups):
    """Reorder the channel axis of a (B, N, C) tensor.

    The C channels are laid out as a (C // groups, groups) grid, the grid is
    transposed, and the result is flattened back to C channels. Channels
    whose index is 0 modulo ``groups`` therefore come first, then those with
    index 1 modulo ``groups``, and so on; e.g. [0, 1, 2, 3, 4, 5] with
    ``groups=2`` becomes [0, 2, 4, 1, 3, 5].

    Args:
        x: tensor of shape (B, N, C).
        groups: grid width; C is expected to be a multiple of it.

    Returns:
        A tensor of shape (B, N, C) with the channels reordered.
    """
    batch, points, channels = x.shape
    rows = channels // groups
    grid = x.reshape(batch, points, rows, groups)
    # Swap the two trailing axes so each grid column becomes contiguous.
    swapped = grid.permute(0, 1, 3, 2)
    return swapped.reshape(batch, points, channels)

# Two points with six easily recognisable channel values each, printed
# before and after channel_shuffle with groups=2 to inspect the reordering.
x = torch.Tensor([[[0,1,2,3,4,5],[10,11,12,13,14,15]]])
print(x, "\n",channel_shuffle(x, 2))

tensor([[[ 0.,  1.,  2.,  3.,  4.,  5.],
         [10., 11., 12., 13., 14., 15.]]]) 
 tensor([[[ 0.,  2.,  4.,  1.,  3.,  5.],
         [10., 12., 14., 11., 13., 15.]]])


In [37]:
class GSA(nn.Module):
    """Group-wise self-attention block with channel shuffle and a residual.

    The channel axis is cut into ``groups`` contiguous slices. Each slice goes
    through its own linear layer and a scaled dot-product attention whose
    query and key are the projected slice and whose value is its ELU. The
    slices are concatenated again, passed through ``channel_shuffle``, added
    to the input and normalised with GroupNorm.

    Args:
        channels: number of input and output channels C.
        groups: number of channel groups; must divide ``channels``.
    """

    def __init__(self, channels = 64, groups=1) -> None:
        super(GSA, self).__init__()

        self.channels = channels
        self.groups = groups
        assert self.channels % self.groups == 0, "C must be divisible by groups"
        self.cg = self.channels // self.groups
        # nn.ModuleList rather than a plain list: only registered submodules
        # show up in parameters() / state_dict() and follow .to() / .cuda().
        self.linears = nn.ModuleList(
            [nn.Linear(self.cg, self.cg) for _ in range(self.groups)]
        )
        self.gn = nn.GroupNorm(self.groups, self.channels)

    def forward(self, x):
        """Apply the block.

        Args:
            x: tensor of shape (B, N, C) with C == channels.

        Returns:
            Tensor of shape (B, N, C).
        """
        B, N, C = x.shape

        xin = x  # B, N, C, kept for the residual connection

        # Open question from the original author: can this loop be vectorized?
        x_g = []
        for i, linear in enumerate(self.linears):
            h = linear(xin[:, :, i*self.cg:(i+1)*self.cg])  # B, N, C//groups
            h = F.scaled_dot_product_attention(h, h, F.elu(h))
            x_g.append(h)
        x = torch.cat(x_g, dim=-1)  # B, N, C

        # GroupNorm expects channels on dim 1, hence the transpose round trip.
        x = self.gn((channel_shuffle(x, self.groups) + xin).transpose(1,2)).transpose(1,2)  # B, N, C

        return x
    

# Smoke test: GSA with the default 64 channels split into 2 groups, applied
# to a random (32, 1024, 64) batch; only the output shape is printed.
gsa = GSA(groups=2)
x = torch.randn(32, 1024, 64)
y = gsa(x)
print(y.shape)

torch.Size([32, 1024, 64])
