In [3]:
import torch
import torch.nn as nn
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        """
        Args:
            d_model: 词向量的维度
            max_len: 位置嵌入的最大序列长度
        """
        super(PositionalEncoding, self).__init__()
        
        # 初始化位置嵌入矩阵
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # 位置索引（0, 1, ..., max_len-1）
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))  # 计算频率项

        # 偶数维使用sin，奇数维使用cos
        pe[:, 0::2] = torch.sin(position * div_term)  # 偶数位置：sin
        pe[:, 1::2] = torch.cos(position * div_term)  # 奇数位置：cos
        
        # 添加额外的维度使其适合与输入相加，并冻结参数
        pe = pe.unsqueeze(0).transpose(0, 1)  # Shape: (max_len, 1, d_model)
        self.register_buffer('pe', pe)  # 将pe注册为buffer，确保在训练时不会被更新

    def forward(self, x):
        """
        Args:
            x: 输入张量，形状为 (seq_len, batch_size, d_model)
        """
        # 将位置嵌入与输入相加
        print(self.pe[:x.size(0), :].shape)

        x = x + self.pe[:x.size(0), :]
        return x

# 测试位置嵌入
seq_len, batch_size, d_model = 10, 32, 512  # 定义序列长度、批量大小和嵌入维度
x = torch.zeros(seq_len, batch_size, d_model)  # 创建一个零输入
pos_encoder = PositionalEncoding(d_model)  # 初始化位置编码器
x = pos_encoder(x)  # 应用位置嵌入

print(x.shape)  # 输出 (seq_len, batch_size, d_model)

torch.Size([10, 1, 512])
torch.Size([10, 32, 512])


In [6]:
import torch
d_model = 512
max_len = 10
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
a = torch.arange(0, d_model, 2)
print(position)
print(a)
(a*position).shape

tensor([[0.],
        [1.],
        [2.],
        [3.],
        [4.],
        [5.],
        [6.],
        [7.],
        [8.],
        [9.]])
tensor([  0,   2,   4,   6,   8,  10,  12,  14,  16,  18,  20,  22,  24,  26,
         28,  30,  32,  34,  36,  38,  40,  42,  44,  46,  48,  50,  52,  54,
         56,  58,  60,  62,  64,  66,  68,  70,  72,  74,  76,  78,  80,  82,
         84,  86,  88,  90,  92,  94,  96,  98, 100, 102, 104, 106, 108, 110,
        112, 114, 116, 118, 120, 122, 124, 126, 128, 130, 132, 134, 136, 138,
        140, 142, 144, 146, 148, 150, 152, 154, 156, 158, 160, 162, 164, 166,
        168, 170, 172, 174, 176, 178, 180, 182, 184, 186, 188, 190, 192, 194,
        196, 198, 200, 202, 204, 206, 208, 210, 212, 214, 216, 218, 220, 222,
        224, 226, 228, 230, 232, 234, 236, 238, 240, 242, 244, 246, 248, 250,
        252, 254, 256, 258, 260, 262, 264, 266, 268, 270, 272, 274, 276, 278,
        280, 282, 284, 286, 288, 290, 292, 294, 296, 298, 300, 302, 304, 306,
 

torch.Size([10, 256])