In [2]:
import torch
import torch.nn as nn

# Run on the first CUDA device when one is available; otherwise use the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
import torch.nn.functional as F


def square_distance(src, dst):
    """Return the squared Euclidean distance between every src/dst point pair.

    ``src`` receives a new axis at position 2 and ``dst`` one at position 1,
    so the subtraction broadcasts over all pairs; the squared differences are
    then summed over the last axis.  For ``src`` of shape (B, S, C) and
    ``dst`` of shape (B, N, C) the result has shape (B, S, N).
    """
    pairwise_diff = src.unsqueeze(2) - dst.unsqueeze(1)
    return pairwise_diff.pow(2).sum(dim=-1)


def index_points(points, idx):
    """Gather rows of ``points`` along dim 1 using per-batch indices.

    Args:
        points: tensor indexed as (B, N, D).
        idx: integer tensor whose first dimension is B; any trailing shape.

    Returns:
        Tensor of shape ``idx.shape + (D,)`` where entry ``[b, ...]`` is
        ``points[b, idx[b, ...]]``.
    """
    out_shape = tuple(idx.shape)
    # Flatten every trailing index dimension so a single gather along dim 1 suffices.
    flat_idx = idx.reshape(out_shape[0], -1)
    # gather needs an index of the same rank as points: repeat each index over D.
    gather_index = flat_idx.unsqueeze(-1).expand(-1, -1, points.size(-1))
    gathered = points.gather(1, gather_index)
    return gathered.reshape(*out_shape, -1)


def farthest_point_sample(xyz, npoint):
    """Select ``npoint`` indices per batch element by farthest point sampling.

    Starting from one randomly chosen point per batch element, each step
    picks the point whose squared distance to the already selected set is
    largest.

    Args:
        xyz: tensor of shape (B, N, C) holding N points with C coordinates.
        npoint: number of indices to select for every batch element.

    Returns:
        Long tensor of shape (B, npoint) with indices into the N axis of
        ``xyz``.  The first index of each row comes from ``torch.randint``,
        so the result depends on the global torch RNG state.
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long, device=device)
    # Squared distance from every point to its nearest selected centroid;
    # starts at a large sentinel so the first update always wins the min.
    distance = torch.full((B, N), 1e10, device=device)
    # Drawn from the default generator and then moved, exactly as before, so
    # seeded runs keep selecting the same starting points.
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long, device=device)
    for i in range(npoint):
        centroids[:, i] = farthest
        # Reshape with C (not a literal 3) so inputs with any number of
        # coordinates per point are accepted.
        centroid = xyz[batch_indices, farthest, :].view(B, 1, C)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        distance = torch.min(distance, dist)
        farthest = torch.max(distance, -1)[1]
    return centroids


def query_ball_point(radius, nsample, xyz, new_xyz):
    """Return, for each query point, up to ``nsample`` indices of points in ``xyz``
    whose squared distance to it does not exceed ``radius ** 2``.

    Args:
        radius: ball radius; compared against squared distances as ``radius ** 2``.
        nsample: number of index columns kept per query point.
        xyz: tensor of shape (B, N, C) with the candidate points.
        new_xyz: tensor of shape (B, S, C) with the query points.

    Returns:
        Long tensor of indices, (B, S, nsample) when ``nsample <= N``.  Slots
        that had no in-radius point are filled with the first index kept for
        that query.
    """
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    point_ids = torch.arange(N, dtype=torch.long).to(xyz.device)
    candidates = point_ids.view(1, 1, N).repeat([B, S, 1])
    # Mark out-of-radius points with the sentinel N so an ascending sort
    # pushes them behind every real index.
    outside_ball = square_distance(new_xyz, xyz) > radius ** 2
    candidates[outside_ball] = N
    nearest = candidates.sort(dim=-1)[0][:, :, :nsample]
    # Replace any sentinel that survived the cut with the row's first index.
    fallback = nearest[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
    is_sentinel = nearest == N
    nearest[is_sentinel] = fallback[is_sentinel]
    return nearest


def sample_and_group(npoint, nsample, xyz, points):
    """Sample ``npoint`` centroids and build a local feature group around each.

    Centroids are chosen with ``farthest_point_sample``; for every centroid
    the ``nsample`` points with the smallest squared distance are grouped.

    Args:
        npoint: number of centroids to sample per batch element.
        nsample: number of neighbours gathered per centroid.
        xyz: tensor of shape (B, N, C) with point coordinates.
        points: tensor of shape (B, N, D) with per-point features.

    Returns:
        Tuple ``(new_xyz, new_points)``: the centroid coordinates
        (B, npoint, C) and, per neighbour, the feature offset from the
        centroid concatenated with the centroid feature
        (B, npoint, nsample, 2 * D).
    """
    B, _, _ = xyz.shape

    centroid_idx = farthest_point_sample(xyz, npoint)  # [B, npoint]

    new_xyz = index_points(xyz, centroid_idx)
    centre_feats = index_points(points, centroid_idx).view(B, npoint, 1, -1)

    # Indices of the nsample closest points to every centroid: B x npoint x K
    knn_idx = square_distance(new_xyz, xyz).argsort()[:, :, :nsample]

    neighbour_feats = index_points(points, knn_idx)
    relative_feats = neighbour_feats - centre_feats
    tiled_centre = centre_feats.repeat(1, 1, nsample, 1)
    return new_xyz, torch.cat([relative_feats, tiled_centre], dim=-1)

In [38]:
class Local_op(nn.Module):
    """Shared point-wise MLP followed by a max over each local group.

    ``forward`` takes a 4-D tensor ``(b, n, s, d)``, folds ``b`` and ``n``
    into one batch axis, applies two ``Conv1d(kernel_size=1) -> BatchNorm1d
    -> ReLU`` stages over the ``d`` channels, takes the maximum over the
    ``s`` axis and returns a tensor of shape ``(b, out_channels, n)``.
    The shape of each intermediate tensor is printed to stdout.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        print('what is x like: ', x.shape)
        b, n, s, d = x.size()

        # (b, n, s, d) -> (b * n, d, s): channels first, one row per group.
        x = x.transpose(2, 3).reshape(-1, d, s)
        print('what is x like: ', x.shape)

        batch_size = x.size(0)
        print('Start change: ')
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = self.relu(norm(conv(x)))  # B, D, N
            print('what is x like: ', x.shape)

        # Collapse the s axis, keeping the strongest response per channel.
        x = x.max(dim=2).values
        print('what is x like: ', x.shape)

        x = x.view(batch_size, -1)
        # Unfold the batch axis back into (b, n) and put channels second.
        return x.reshape(b, n, -1).permute(0, 2, 1)


In [41]:
# Smoke-test Local_op on a constant (4, 256, 32, 256) tensor and show the output shape.
lo_input = torch.ones((4, 256, 32, 256))
lo = Local_op(in_channels=256, out_channels=256)
out = lo(lo_input)
print(out.size())

what is x like:  torch.Size([4, 256, 32, 256])
what is x like:  torch.Size([1024, 256, 32])
Start change: 
what is x like:  torch.Size([1024, 256, 32])
what is x like:  torch.Size([1024, 256, 32])
what is x like:  torch.Size([1024, 256])
torch.Size([4, 256, 256])


In [36]:
class SA_Layer(nn.Module):
    """Self-attention layer over a ``(b, channels, n)`` feature tensor.

    Queries and keys are 1x1 convolutions down to ``channels // 4`` that
    share one weight tensor.  The attention map is softmax-normalised over
    its last axis and then divided by its sum over dim 1.  The attended
    values are subtracted from the input, passed through
    ``Conv1d -> BatchNorm1d -> ReLU`` and added back to the input.
    The shape of each intermediate tensor is printed to stdout.
    """

    def __init__(self, channels):
        super().__init__()
        self.q_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)
        self.k_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)
        # Tie the query projection to the key projection (one shared Parameter).
        self.q_conv.weight = self.k_conv.weight
        self.v_conv = nn.Conv1d(channels, channels, 1)
        self.trans_conv = nn.Conv1d(channels, channels, 1)
        self.after_norm = nn.BatchNorm1d(channels)
        self.act = nn.ReLU()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        x_q = self.q_conv(x).transpose(1, 2)  # b, n, c
        print('what is x_q like: ', x_q.shape)

        x_k = self.k_conv(x)  # b, c, n
        print('what is x_k like: ', x_k.shape)

        x_v = self.v_conv(x)
        print('what is x_v like: ', x_v.shape)

        energy = torch.matmul(x_q, x_k)  # b, n, n
        print('what is energy like: ', energy.shape)

        attention = self.softmax(energy)
        print('what is attention like: ', attention.shape)

        # Second normalisation: divide by the sum over dim 1 (epsilon guards against 0).
        column_mass = attention.sum(dim=1, keepdim=True)
        attention = attention / (1e-9 + column_mass)
        print('what is attention like: ', attention.shape)

        x_r = torch.matmul(x_v, attention)  # b, c, n
        print('what is x_r like: ', x_r.shape)

        # Transform the difference between the input and its attended version.
        offset = self.trans_conv(x - x_r)
        x_r = self.act(self.after_norm(offset))
        print('what is x_r like: ', x_r.shape)

        x = x + x_r
        print('what is x_r like: ', x_r.shape)

        return x

In [37]:
# Smoke-test a single SA_Layer on a constant (4, 256, 256) tensor.
input_sa = torch.ones((4, 256, 256))
sa = SA_Layer(256)
# Pass the tensor built in this cell. The call previously used `input_sta`,
# which is only created by a later cell, so a fresh top-to-bottom run
# raised NameError here.
out = sa(input_sa)
print(out.shape)

what is x_q like:  torch.Size([4, 256, 64])
what is x_k like:  torch.Size([4, 64, 256])
what is x_v like:  torch.Size([4, 256, 256])
what is energy like:  torch.Size([4, 256, 256])
what is attention like:  torch.Size([4, 256, 256])
what is attention like:  torch.Size([4, 256, 256])
what is x_r like:  torch.Size([4, 256, 256])
what is x_r like:  torch.Size([4, 256, 256])
what is x_k like:  torch.Size([4, 64, 256])
torch.Size([4, 256, 256])


In [32]:
class StackedAttention(nn.Module):
    """Two point-wise conv stages followed by four chained SA_Layer blocks.

    ``forward`` takes a ``(B, channels, N)`` tensor, applies
    ``Conv1d(kernel_size=1) -> BatchNorm1d -> ReLU`` twice, feeds the result
    through ``sa1 .. sa4`` in sequence and returns the four attention
    outputs concatenated along dim 1, i.e. ``(B, 4 * channels, N)``.
    Progress is traced on stdout.
    """

    def __init__(self, channels=256):
        super().__init__()
        self.conv1 = nn.Conv1d(channels, channels, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(channels, channels, kernel_size=1, bias=False)

        self.bn1 = nn.BatchNorm1d(channels)
        self.bn2 = nn.BatchNorm1d(channels)

        self.sa1 = SA_Layer(channels)
        self.sa2 = SA_Layer(channels)
        self.sa3 = SA_Layer(channels)
        self.sa4 = SA_Layer(channels)

        self.relu = nn.ReLU()

    def forward(self, x):
        # The unpacked values are not used later; the unpack itself requires a 3-D input.
        batch_size, _, N = x.size()
        print('what is x like: ', x.shape)

        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = self.relu(norm(conv(x)))  # B, D, N
            print('what is x like: ', x.shape)

        print("Start the attention module!")
        stage_outputs = []
        current = x
        for attention_layer in (self.sa1, self.sa2, self.sa3, self.sa4):
            current = attention_layer(current)
            stage_outputs.append(current)
            # x is not reassigned inside this loop, so this reports the
            # shape of the tensor that entered the attention stack.
            print('what is x like: ', x.shape)

        x = torch.cat(stage_outputs, dim=1)
        print('what is x like: ', x.shape)

        return x

In [33]:
# Smoke-test StackedAttention on a constant (4, 256, 256) tensor and show the output shape.
input_sta = torch.ones(4, 256, 256)
sta = StackedAttention(channels=256)
out = sta(input_sta)
print(out.size())

what is x like:  torch.Size([4, 256, 256])
what is x like:  torch.Size([4, 256, 256])
what is x like:  torch.Size([4, 256, 256])
Start the attention module!
what is x like:  torch.Size([4, 256, 256])
what is x like:  torch.Size([4, 256, 256])
what is x like:  torch.Size([4, 256, 256])
what is x like:  torch.Size([4, 256, 256])
what is x like:  torch.Size([4, 1024, 256])
torch.Size([4, 1024, 256])


In [28]:
class PCT_Encoder(nn.Module):
    """Point-cloud encoder: point embedding, two sample-and-group stages,
    stacked self-attention and a global max-pool.

    Args:
        d_points: number of per-point input features consumed by the first
            convolution.  Defaults to 3, the value that used to be
            hard-coded.  Whatever the value, the first three features of
            each point are taken as its coordinates for sampling and grouping.

    ``forward`` expects ``x`` of shape (B, N, d_points) and returns a
    (B, 1024) tensor.  The shape of each intermediate tensor is printed to
    stdout.
    """

    def __init__(self, d_points=3):
        super().__init__()
        # Point-wise embedding: d_points -> 64 -> 64 channels.
        self.conv1 = nn.Conv1d(d_points, 64, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(64, 64, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(64)
        # sample_and_group concatenates offset and centroid features, which
        # doubles the channel count: 64 -> 128, then 128 -> 256.
        self.gather_local_0 = Local_op(in_channels=128, out_channels=128)
        self.gather_local_1 = Local_op(in_channels=256, out_channels=256)
        self.pt_last = StackedAttention()

        self.relu = nn.ReLU()
        # 1280 = 4 * 256 channels from the attention stack + 256 from feature_1.
        self.conv_fuse = nn.Sequential(nn.Conv1d(1280, 1024, kernel_size=1, bias=False),
                                       nn.BatchNorm1d(1024),
                                       nn.LeakyReLU(negative_slope=0.2))

    def forward(self, x):
        # Coordinates used for farthest point sampling and neighbour search.
        xyz = x[..., 0:3]
        x = x.permute(0, 2, 1)  # B, d_points, N (channels first for Conv1d)
        batch_size, _, _ = x.size()
        x = self.relu(self.bn1(self.conv1(x)))  # B, D, N
        print('what is x like: ', x.shape)
        x = self.relu(self.bn2(self.conv2(x)))  # B, D, N
        print('what is x like: ', x.shape)
        x = x.permute(0, 2, 1)  # back to B, N, D for grouping
        print('what is x like: ', x.shape)
        # First down-sampling stage: 512 centroids with 32 neighbours each.
        new_xyz, new_feature = sample_and_group(npoint=512, nsample=32, xyz=xyz, points=x)
        print('new_xyz, new_feature: ', new_xyz.shape, new_feature.shape)
        feature_0 = self.gather_local_0(new_feature)
        print('feature_0: ', feature_0.shape)
        feature = feature_0.permute(0, 2, 1)
        print('feature: ', feature.shape)
        # Second down-sampling stage: 256 centroids with 32 neighbours each.
        new_xyz, new_feature = sample_and_group(npoint=256, nsample=32, xyz=new_xyz, points=feature)
        print("Again new_xyz, new_feature: ", new_xyz.shape, new_feature.shape)
        feature_1 = self.gather_local_1(new_feature)
        print('feature_1: ', feature_1.shape)

        x = self.pt_last(feature_1)
        print('what is x like: ', x.shape)
        # Append the pre-attention features to the attention outputs.
        x = torch.cat([x, feature_1], dim=1)
        print('after cat what is x like: ', x.shape)
        x = self.conv_fuse(x)
        print('after conv_fuse what is x like: ', x.shape)
        # Global max-pool over the point axis.
        x = torch.max(x, 2)[0]
        print('after torch.max: ', x.shape)
        x = x.view(batch_size, -1)
        return x

In [29]:
# Constant point cloud: 4 clouds of 1024 points with 3 features each.
input_points = torch.ones(4, 1024, 3)

In [30]:
# Run the full encoder once on the constant point cloud and show the output shape.
net = PCT_Encoder()
out = net(input_points)
print(out.size())

what is x like:  torch.Size([4, 64, 1024])
what is x like:  torch.Size([4, 64, 1024])
what is x like:  torch.Size([4, 1024, 64])
new_xyz, new_feature:  torch.Size([4, 512, 3]) torch.Size([4, 512, 32, 128])
feature_0:  torch.Size([4, 128, 512])
feature:  torch.Size([4, 512, 128])
Again new_xyz, new_feature:  torch.Size([4, 256, 3]) torch.Size([4, 256, 32, 256])
feature_1:  torch.Size([4, 256, 256])
what is x like:  torch.Size([4, 1024, 256])
after cat what is x like:  torch.Size([4, 1280, 256])
after conv_fuse what is x like:  torch.Size([4, 1024, 256])
after torch.max:  torch.Size([4, 1024])
torch.Size([4, 1024])
