In [1]:
import math
import torch
from torch import nn
from d2l import torch as d2l

In [3]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to embedded tokens.

    For position ``i`` and frequency index ``j`` the angle is
    ``i / 10000 ** (2j / num_hiddens)``. Even feature columns ``2j`` store
    its sine and odd feature columns ``2j + 1`` store its cosine. The table
    is computed once for ``max_len`` positions and added to the input in
    ``forward``, followed by dropout.

    Args:
        num_hiddens: Embedding dimension (size of the last input axis).
            Both even and odd values are supported.
        dropout: Dropout probability applied after adding the encoding.
        max_len: Number of positions that are precomputed; ``forward``
            rejects sequences longer than this.
    """

    def __init__(self, num_hiddens, dropout, max_len=1000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # Angles have shape (max_len, ceil(num_hiddens / 2)): one row per
        # position, one column per sin/cos frequency.
        positions = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1)
        frequencies = torch.pow(10000, torch.arange(
            0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        angles = positions / frequencies
        # A sufficiently long table P with a leading batch axis of size 1;
        # each row is the encoding for one embedded token position.
        P = torch.zeros((1, max_len, num_hiddens))
        P[:, :, 0::2] = torch.sin(angles)  # even columns use sin
        # When num_hiddens is odd there is one fewer odd column than even
        # columns, so drop the last frequency for the cosine half.
        P[:, :, 1::2] = torch.cos(angles[:, :num_hiddens // 2])  # odd columns use cos
        # Non-persistent buffer: follows module.to()/cuda() but stays out of
        # state_dict, so checkpoints saved without a "P" entry still load.
        self.register_buffer("P", P, persistent=False)

    def forward(self, X):
        """Return ``dropout(X + P[:, :X.shape[1], :])``.

        Args:
            X: Tensor whose axis 1 is the sequence axis and whose last axis
                has size ``num_hiddens``.

        Raises:
            ValueError: If the sequence length exceeds ``max_len``.
        """
        seq_len = X.shape[1]
        max_len = self.P.shape[1]
        if seq_len > max_len:
            # Without this check the addition fails with an opaque broadcast
            # error, or silently broadcasts position 0 when max_len == 1.
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {max_len}")
        # .to() is a no-op when P already lives on X's device.
        X = X + self.P[:, :seq_len, :].to(X.device)
        # Dropout keeps the model from becoming overly sensitive to position.
        return self.dropout(X)