In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# *Sin Cos 1D Embedding*

In [34]:
def GetPositionEmbedding(sequence_len, embedding_len):
    div = 10000 ** (torch.arange(0, embedding_len // 2) * 2 / embedding_len)
    positions = torch.arange(0, sequence_len).unsqueeze(1)
    pos_emb = torch.zeros(sequence_len, embedding_len)
    pos_emb[:, 0::2] = torch.sin(positions / div)
    pos_emb[:, 1::2] = torch.cos(positions / div)
    return pos_emb

print(GetPositionEmbedding(5, 3))

tensor([[ 0.0000,  1.0000,  0.0000],
        [ 0.8415,  0.5403,  0.8415],
        [ 0.9093, -0.4161,  0.9093],
        [ 0.1411, -0.9900,  0.1411],
        [-0.7568, -0.6536, -0.7568]])


# *Trainable 1D Embedding*

In [8]:
def GetPositionEmbedding(sequence_len, embedding_len):
    pos_emb = nn.Embedding(sequence_len, embedding_len).weight
    nn.init.constant_(pos_emb, 0.0)
    return pos_emb

print(GetPositionEmbedding(5, 3))

Parameter containing:
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], requires_grad=True)


# *Sin Cos 2D Embedding*

In [43]:
# Now it takes in index instead of just length
def GetPositionEmbedding1D(sequence_indices, embedding_len):
    div = 10000 ** (torch.arange(0, embedding_len // 2) * 2 / embedding_len)
    positions = sequence_indices.unsqueeze(1)
    pos_emb = torch.zeros(sequence_indices.shape[0], embedding_len)
    pos_emb[:, 0::2] = torch.sin(positions / div)
    pos_emb[:, 1::2] = torch.cos(positions / div)
    return pos_emb

def GetPositionEmbedding2D(image_size, embedding_dim):
    pos_emb = torch.zeros(image_size*image_size, embedding_dim)
    coord = torch.stack(torch.meshgrid(torch.arange(0, image_size), torch.arange(0, image_size))) # 2 * W * H
    coord = coord.view(coord.shape[0], -1)
    height_embedding = GetPositionEmbedding1D(coord[0], embedding_dim // 2)  # embedding for 0000111122223333
    width_embedding = GetPositionEmbedding1D(coord[1], embedding_dim // 2)   # embedding for 0123012301230123
    pos_emb[:, embedding_dim//2:] = height_embedding
    pos_emb[:, :embedding_dim//2] = width_embedding
    # Now each embedding has both width and height information
    
    return pos_emb
    
print(GetPositionEmbedding2D(5, 4))

tensor([[ 0.0000,  1.0000,  0.0000,  1.0000],
        [ 0.8415,  0.5403,  0.0000,  1.0000],
        [ 0.9093, -0.4161,  0.0000,  1.0000],
        [ 0.1411, -0.9900,  0.0000,  1.0000],
        [-0.7568, -0.6536,  0.0000,  1.0000],
        [ 0.0000,  1.0000,  0.8415,  0.5403],
        [ 0.8415,  0.5403,  0.8415,  0.5403],
        [ 0.9093, -0.4161,  0.8415,  0.5403],
        [ 0.1411, -0.9900,  0.8415,  0.5403],
        [-0.7568, -0.6536,  0.8415,  0.5403],
        [ 0.0000,  1.0000,  0.9093, -0.4161],
        [ 0.8415,  0.5403,  0.9093, -0.4161],
        [ 0.9093, -0.4161,  0.9093, -0.4161],
        [ 0.1411, -0.9900,  0.9093, -0.4161],
        [-0.7568, -0.6536,  0.9093, -0.4161],
        [ 0.0000,  1.0000,  0.1411, -0.9900],
        [ 0.8415,  0.5403,  0.1411, -0.9900],
        [ 0.9093, -0.4161,  0.1411, -0.9900],
        [ 0.1411, -0.9900,  0.1411, -0.9900],
        [-0.7568, -0.6536,  0.1411, -0.9900],
        [ 0.0000,  1.0000, -0.7568, -0.6536],
        [ 0.8415,  0.5403, -0.7568

# *Relative Bias 2D Embedding*

In [32]:
# Used in Vit attention
# q*kT + B  (B: embedding)
def GetPositionEmbedding(image_size, embedding_len):
    coord = torch.stack(torch.meshgrid(torch.arange(0, image_size), torch.arange(0, image_size)))
    
    
    
    coord = coord.view(coord.shape[0], -1)
        
    bias = coord[:, :, None] - coord[:, None, :] + image_size - 1
    return bias[0] * ((image_size-1) * 2 + 1) + bias[1]
    
    
print(GetPositionEmbedding(9, 3))

tensor([[12, 11, 10,  7,  6,  5,  2,  1,  0],
        [13, 12, 11,  8,  7,  6,  3,  2,  1],
        [14, 13, 12,  9,  8,  7,  4,  3,  2],
        [17, 16, 15, 12, 11, 10,  7,  6,  5],
        [18, 17, 16, 13, 12, 11,  8,  7,  6],
        [19, 18, 17, 14, 13, 12,  9,  8,  7],
        [22, 21, 20, 17, 16, 15, 12, 11, 10],
        [23, 22, 21, 18, 17, 16, 13, 12, 11],
        [24, 23, 22, 19, 18, 17, 14, 13, 12]])
