In [196]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math, copy, time
import matplotlib.pyplot as plt
import seaborn as sns

# Use seaborn's "talk" plotting context for every figure drawn after this point.
sns.set_context("talk")

In [197]:
class Embeddings(nn.Module):
    """Embedding lookup whose output is multiplied by sqrt(d_model).

    Args:
        d_model: Size of each embedding vector.
        vocab: Number of rows in the embedding table.
    """

    def __init__(self, d_model, vocab):
        super(Embeddings, self).__init__()
        self.d_model = d_model  # kept so forward() can compute the scale factor
        self.lut = nn.Embedding(num_embeddings=vocab, embedding_dim=d_model)  # lookup table

    def forward(self, x):
        """Look up the indices in ``x`` and scale the vectors by sqrt(d_model)."""
        scale = math.sqrt(self.d_model)
        looked_up = self.lut(x)
        return looked_up * scale

In [198]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to the input, then apply dropout.

    The table is built once in ``__init__`` and stored as the buffer ``pe``
    with shape ``(1, max_len, d_model)``. Even columns hold
    ``sin(pos * freq)`` and odd columns hold ``cos(pos * freq)``, where
    ``freq = exp(-log(10000) * 2i / d_model)`` for column pair ``i``.

    Args:
        d_model: Size of the last dimension of the inputs. Both even and odd
            values are supported.
        dropout: Probability passed to ``nn.Dropout``.
        max_len: Number of positions precomputed in the table.
    """

    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # An odd d_model has one fewer odd column than even columns, so trim
        # the angle table to match; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)  # leading dim of size 1 so it broadcasts over dim 0 of x
        # Registered as a buffer: part of the module's state, not a parameter.
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:, :x.size(1)])``.

        ``x.size(1)`` is used as the number of positions to take from the
        table, so it should not exceed ``max_len``.
        """
        x = x + self.pe[:, : x.size(1)].detach()
        return self.dropout(x)

In [199]:
# Smoke test: embed two 4-token sequences, add positional encodings,
# then print the resulting tensor and its shape.
d_model, vocab = 512, 10000
embeddings = Embeddings(d_model, vocab)
positional_encoding = PositionalEncoding(d_model, dropout=0.1)
x = positional_encoding(
    embeddings(torch.tensor([[1, 2, 3, 1], [6, 5, 89, 0]]))
)
print(x)
print(x.shape)

tensor([[[ 11.2272, -13.6934, -16.0333,  ...,  30.8006,   0.9454,  -7.0989],
         [  3.5160,  -6.1167,  15.4376,  ...,  14.4768,  -6.3042, -20.1314],
         [  7.0921, -31.8241,   3.4255,  ...,   3.7689,  11.1341,  18.0814],
         [ 11.3840, -15.9045,  -0.0000,  ...,  30.8006,   0.9457,  -7.0989]],

        [[ 12.5045,  13.5130,  25.2645,  ..., -35.6063,  49.9737,  -7.7355],
         [ 11.0413, -13.2639, -48.0724,  ...,   0.0000,  -0.0000, -10.1227],
         [ 28.4697,  -0.0000, -12.4270,  ...,  19.4462,  44.1881, -12.1257],
         [ -0.3193, -23.7343, -20.2613,  ...,   0.8307,  31.1339,   0.0000]]],
       grad_fn=<MulBackward0>)
torch.Size([2, 4, 512])
