In [1]:
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

In [2]:
import torch
import torch.nn as nn
import math

In [12]:
class PositionEncoding(nn.Module):
    def __init__(self,max_pos,embed_dim):
        super(PositionEncoding,self).__init__()  #最大位置长度和词向量的维度
        PE = torch.zeros(max_pos,embed_dim)    #保存位置编码的数组
        #生成从零到max_pos-1的位置数组pos
        pos = torch.arange(0, max_pos).unsqueeze(1).float()
        #序列，jiangewei2，对应2i
        multi_term= torch.arange(0,embed_dim,2).float()
        #e^(2i*(-log(10000/d)))
        multi_term = torch.exp(multi_term *(-math.log(1000.0)/embed_dim))

        PE[:,0::2]=torch.sin(pos *multi_term)
        PE[:,1::2] = torch.cos(pos * multi_term)
        #将数组PE注册为不需要梯度更新的缓存数组
        self.register_buffer('PE',PE.unsqueeze(0))
#前向传播函数，
    def forward(self,x):
        return x+ self.PE[:,x.size(1)].clone().detach()

In [13]:
if __name__ == "__main__":
    max_pos =10 #序长度
    embed_dim =4

    model = PositionEncoding(max_pos,embed_dim)
#样本个数为2，长度为5，维度为4
    x =torch.zeros(2,5,embed_dim)
    
    output = model(x)
    print("x:")
    print(x)
    print("PE:")
    print(model.PE)
    print("output:")
    print(output)


x:
tensor([[[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]]])
PE:
tensor([[[ 0.0000,  1.0000,  0.0000,  1.0000],
         [ 0.8415,  0.5403,  0.0316,  0.9995],
         [ 0.9093, -0.4161,  0.0632,  0.9980],
         [ 0.1411, -0.9900,  0.0947,  0.9955],
         [-0.7568, -0.6536,  0.1262,  0.9920],
         [-0.9589,  0.2837,  0.1575,  0.9875],
         [-0.2794,  0.9602,  0.1886,  0.9821],
         [ 0.6570,  0.7539,  0.2196,  0.9756],
         [ 0.9894, -0.1455,  0.2503,  0.9682],
         [ 0.4121, -0.9111,  0.2808,  0.9598]]])
output:
tensor([[[-0.9589,  0.2837,  0.1575,  0.9875],
         [-0.9589,  0.2837,  0.1575,  0.9875],
         [-0.9589,  0.2837,  0.1575,  0.9875],
         [-0.9589,  0.2837,  0.1575,  0.9875],
         [-0.9589,  0.2837,  0.1575,  0.9875]],

  