In [3]:
import torch
from torch import nn

# 位置编码模块
class PositionalEncoding(nn.Module):
    """Add a fixed sinusoidal positional-encoding table to the input, then apply dropout.

    The table ``P`` has shape ``(1, max_len, num_hiddens)``. For position ``i``
    and frequency index ``j``::

        P[0, i, 2j]     = sin(i / 10000 ** (2j / num_hiddens))
        P[0, i, 2j + 1] = cos(i / 10000 ** (2j / num_hiddens))

    Args:
        num_hiddens: Size of the last (feature) dimension of the table.
            May be even or odd.
        dropout: Probability passed to ``nn.Dropout``, applied to the sum.
        max_len: Number of positions precomputed in the table.
    """

    def __init__(self, num_hiddens, dropout, max_len=1000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        # angles[i, j] = i / 10000 ** (2j / num_hiddens); one column per
        # sin/cos frequency, i.e. ceil(num_hiddens / 2) columns.
        positions = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1)
        exponents = torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens
        angles = positions / torch.pow(10000, exponents)

        table = torch.zeros((1, max_len, num_hiddens))
        table[:, :, 0::2] = torch.sin(angles)  # even feature indices
        # Odd feature indices. When num_hiddens is odd there is one fewer odd
        # slot than even slots, so drop the surplus frequency column instead
        # of failing on a shape mismatch.
        table[:, :, 1::2] = torch.cos(angles[:, :num_hiddens // 2])

        # Register as a buffer so Module.to()/cuda()/cpu() move the table with
        # the module. persistent=False keeps it out of state_dict, so
        # checkpoints saved when P was a plain attribute still load.
        self.register_buffer("P", table, persistent=False)

    def forward(self, X):
        """Return ``dropout(X + P[:, :X.shape[1], :])``.

        Args:
            X: Tensor whose dimension 1 indexes positions; expected to be
                ``(batch, num_steps, num_hiddens)`` so that it broadcasts
                against the sliced table.

        Returns:
            Tensor produced by adding the first ``X.shape[1]`` rows of the
            table to ``X`` and passing the sum through dropout.
        """
        # .to() is a no-op when the buffer already lives on X's device; it is
        # kept so inputs on another device than the module still work.
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)

# Setup: 8-dimensional encoding over a 4-step sequence. Dropout is 0 and the
# module is in eval mode so the raw values can be inspected.
encoding_dim, num_steps = 8, 4
pos_encoding = PositionalEncoding(encoding_dim, dropout=0).eval()

# All-ones input sequence X with shape (1, 4, 8).
X = torch.ones(1, num_steps, encoding_dim)

# Encoding vectors for the first 4 positions.
P = pos_encoding.P[:, :num_steps]

# Show the input, the positional encoding, and their sum (batch element 0).
print("输入 X:", X[0], sep="\n")
print("\n位置编码 P:", P[0], sep="\n")
print("\n相加后的结果 X + P:", X[0] + P[0], sep="\n")


输入 X:
tensor([[1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

位置编码 P:
tensor([[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00,  1.0000e+00],
        [ 8.4147e-01,  5.4030e-01,  9.9833e-02,  9.9500e-01,  9.9998e-03,
          9.9995e-01,  1.0000e-03,  1.0000e+00],
        [ 9.0930e-01, -4.1615e-01,  1.9867e-01,  9.8007e-01,  1.9999e-02,
          9.9980e-01,  2.0000e-03,  1.0000e+00],
        [ 1.4112e-01, -9.8999e-01,  2.9552e-01,  9.5534e-01,  2.9995e-02,
          9.9955e-01,  3.0000e-03,  1.0000e+00]])

相加后的结果 X + P:
tensor([[1.0000, 2.0000, 1.0000, 2.0000, 1.0000, 2.0000, 1.0000, 2.0000],
        [1.8415, 1.5403, 1.0998, 1.9950, 1.0100, 1.9999, 1.0010, 2.0000],
        [1.9093, 0.5839, 1.1987, 1.9801, 1.0200, 1.9998, 1.0020, 2.0000],
        [1.1411, 0.0100, 1.2955, 1.9553, 1.0300, 1.9996, 1.0030, 2.0000]])


In [None]:
# Column index 3 of X + P for positions 0-3, i.e. 1 + cos(pos / 10):
# [2.0000, 1.9950, 1.9801, 1.9553]