In [1]:
from torch import nn
import torch 
import math

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to embeddings, then apply dropout.

    The table is precomputed for positions ``0 .. max_len - 1``:

        PE(pos, 2i)     = sin(pos / 10000^(2i / dim))
        PE(pos, 2i + 1) = cos(pos / 10000^(2i / dim))

    Args:
        dim: Embedding size. Must be even so that sin/cos columns pair up.
        dropout: Dropout probability applied to the encoded output.
        max_len: Number of positions precomputed, i.e. the longest sequence
            (or largest ``step`` + 1) that can be encoded.

    Raises:
        ValueError: If ``dim`` is odd.
    """

    def __init__(self, dim, dropout, max_len = 5000):
        super(PositionalEncoding, self).__init__()

        if dim % 2 != 0:
            raise ValueError("Cannot use sin/cos positional encoding with "
                             "odd dim (got dim={:d})".format(dim))

        pe = torch.zeros(max_len, dim)
        # Column vector of positions, shape (max_len, 1).
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # 1 / 10000^(2i/dim) for each even index 2i, computed in log space.
        div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.float) *
                             -(math.log(10000.0) / dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Add a size-1 batch axis -> (max_len, 1, dim) so the table broadcasts
        # over the batch dimension when added in forward().
        pe = pe.unsqueeze(1)
        # Store as a buffer (not a trainable parameter): forward() reads
        # self.pe, and buffers move with the module on .to(device)/.cuda().
        self.register_buffer('pe', pe)
        self.drop_out = nn.Dropout(p=dropout)
        self.dim = dim

    def forward(self, emb, step = None):
        """Scale ``emb`` by sqrt(dim), add positional encodings, apply dropout.

        Args:
            emb: Embedding tensor whose first axis indexes positions and whose
                last axis has size ``dim``. Presumably (seq_len, batch, dim);
                confirm against callers.
            step: If ``None``, positions ``0 .. emb.size(0) - 1`` are added.
                Otherwise the single encoding for position ``step`` is
                broadcast-added (e.g. for step-by-step decoding).

        Returns:
            Tensor with the same shape as ``emb`` (after broadcasting).
        """
        # Scale the embeddings so their magnitude is comparable to the
        # positional encoding and the encoding does not dominate them.
        emb = emb * math.sqrt(self.dim)
        if step is None:
            # self.pe[:seq_len] has shape (seq_len, 1, dim) and broadcasts
            # over the batch axis.
            emb = emb + self.pe[:emb.size(0)]
        else:
            emb = emb + self.pe[step]
        return self.drop_out(emb)
                                