In [11]:
import numpy as np
import torch
import torch.nn as nn


class LearnableFourierPositionalEncoding(nn.Module):
    def __init__(self, G: int, M: int, F_dim: int, H_dim: int, D: int, gamma: float):
        """
        Learnable Fourier Features from https://arxiv.org/pdf/2106.02795.pdf (Algorithm 1)
        Implementation of Algorithm 1: Compute the Fourier feature positional encoding of a multi-dimensional position
        Computes the positional encoding of a tensor of shape [N, G, M]
        :param G: positional groups (positions in different groups are independent)
        :param M: each point has a M-dimensional positional values
        :param F_dim: depth of the Fourier feature dimension
        :param H_dim: hidden layer dimension
        :param D: positional encoding dimension
        :param gamma: parameter to initialize Wr
        """
        super().__init__()
        self.G = G # 3
        self.M = M # 17
        self.F_dim = F_dim # 768
        self.H_dim = H_dim # 32
        self.D = D # 768
        self.gamma = gamma # 10

        # Projection matrix on learned lines (used in eq. 2)
        self.Wr = nn.Linear(self.M, self.F_dim // 2, bias=False)
        # MLP (GeLU(F @ W1 + B1) @ W2 + B2 (eq. 6)
        self.mlp = nn.Sequential(
            nn.Linear(self.F_dim, self.H_dim, bias=True),
            nn.GELU(),
            nn.Linear(self.H_dim, self.D // self.G)
        )

        self.init_weights()

    def init_weights(self):
        nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma ** -2)

    def forward(self, x):
        """
        Produce positional encodings from x
        :param x: tensor of shape [N, G, M] that represents N positions where each position is in the shape of [G, M],
                  where G is the positional group and each group has M-dimensional positional values.
                  Positions in different positional groups are independent
        :return: positional encoding for X
        """
        N, G, M = x.shape
        # Step 1. Compute Fourier features (eq. 2)
        projected = self.Wr(x)
        cosines = torch.cos(projected)
        sines = torch.sin(projected)
        F = 1 / np.sqrt(self.F_dim) * torch.cat([cosines, sines], dim=-1)
        # Step 2. Compute projected Fourier features (eq. 6)
        print(F.size())
        Y = self.mlp(F)
        print(Y.size())
        # Step 3. Reshape to x's shape
        PEx = Y.reshape((N, self.D))
        return PEx


if __name__ == '__main__':
    G = 2
    M = 2
    x = torch.randn((30, G, M)) # 30個點，2組(正負樣本)，座標
    enc = LearnableFourierPositionalEncoding(G, M, 768, 32, 768, 10)
    pex = enc(x)
    print(pex.shape)

torch.Size([30, 2, 768])
torch.Size([30, 2, 384])
torch.Size([30, 768])


In [1]:
import numpy as np
import torch
seg = torch.ones(1,128,128,128)
l = len(torch.where(seg == 1)[0])
print(l)
sample = np.random.choice(np.arange(l), 10, replace=True) # 從範圍為 [0, l) 的整數中隨機選取 10 個數字（可能有重複）
print(sample)
x = torch.where(seg == 1)[1][sample].unsqueeze(1)
y = torch.where(seg == 1)[3][sample].unsqueeze(1)
z = torch.where(seg == 1)[2][sample].unsqueeze(1)
print(z)
point_coord = torch.cat([x, y, z], dim=1).unsqueeze(1).float() 

foo = torch.randn(1,20,3)
point_coord = point_coord.transpose(0,1)
point_coord = torch.cat([point_coord,foo],dim=1)
print(point_coord.size())

point_coord.reshape(1,1,1,-1,3).size()

2097152
[ 940785 1887671  907095  510941 1211251 1157562 1431325 1086533  838822
   24127]
tensor([[ 53],
        [ 27],
        [ 46],
        [ 23],
        [118],
        [ 83],
        [ 46],
        [ 40],
        [ 25],
        [ 60]])
torch.Size([1, 30, 3])


torch.Size([1, 1, 1, 30, 3])

In [2]:
x = torch.randn(3, 2)
print(x)
torch.where(x > 0)

tensor([[ 0.0637, -0.6329],
        [-0.6641, -1.6162],
        [-1.0634,  1.4174]])


(tensor([0, 2]), tensor([0, 1]))

In [3]:
x = torch.randn(10,1)
y = torch.randn(10,1)
z = torch.randn(10,1)
points = torch.cat([x, y, z], dim=1).unsqueeze(1).float()
print(points.size())
points= points.transpose(0,1)
a = torch.randn(1,10,3)
print(a.size())
print(points.size())
torch.cat([a,points],dim=1).size()

torch.Size([10, 1, 3])
torch.Size([1, 10, 3])
torch.Size([1, 10, 3])


torch.Size([1, 20, 3])

In [4]:
import torch.nn.functional as F
img = torch.randn(1,3, 128,128,128)
# img[:, 0].size()
F.interpolate(img,scale_factor=0.5,mode='trilinear').size()

torch.Size([1, 3, 64, 64, 64])