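# Transformer Model
This notebook implements the main building blocks of the Transformer, including multi-head attention and LayerNorm. The resulting `Transformer`, `TransformerEncoder`/`TransformerDecoder`, and `TransformerEncoderLayer`/`TransformerDecoderLayer` classes take nearly the same parameters and expose nearly the same interface as their PyTorch counterparts. They are not guaranteed to be interchangeable, but they look and behave much the same, and the code is considerably simpler than PyTorch's implementation. This file has the same content as model.py; the notebook just adds explanatory text. Let's get started.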

<img src="https://image.panwenbo.icu/blog20210728094709.jpg" alt="v2-22a369f0f1b0d542ced248dcb215b6e8_1440w" style="zoom:50%;" />

References

1. https://zhuanlan.zhihu.com/p/48731949 (source of the overall structure and some code details)
2. https://pytorch.org/tutorials/beginner/translation_transformer.html (source of the positional encoding design)
3. https://pytorch.org/tutorials/beginner/transformer_tutorial.html (consulted only lightly)
4. https://arxiv.org/abs/1706.03762 (source of some details)




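## 1. Implementing the Basic Transformer Structure
The basic building block of the Transformer is a multi-head attention sublayer followed by a fully connected network. Each of these two sublayers is wrapped in a residual connection and LayerNorm: the sublayer's input and output are added together and then normalized with LayerNorm.

### 1.1 SublayerWrapper (Add & Norm in the figure)
So that inputs and outputs can be added together, every component in the Transformer takes and returns tensors of the same shape: (L, N, `d_model`), where L is the sequence length, N is the batch size, and `d_model` is the feature size. The original paper uses `d_model=512`. Dropout is applied to the sublayer output before it is added to the input.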

<img src="https://image.panwenbo.icu/blog20210728095532.png" alt="截屏2021-07-28 上午9.55.24" style="zoom:20%;" />

In [1]:
import torch
import torch.nn as nn

class SublayerWrapper(nn.Module):
    def __init__(self, sub_layer, d_model, dropout_r):
        super(SublayerWrapper, self).__init__()
        self.dropout = nn.Dropout(dropout_r)
        self.layernorm = nn.LayerNorm(d_model)
        self.sub_layer = sub_layer
    
    def forward(self, *args, **kwargs):
        output = self.sub_layer(*args, **kwargs)
        return self.layernorm(args[0] + self.dropout(output))

In [2]:
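# Quick test
# Make sure the wrapped module is actually registered. A sub-module's parameters
# only show up in model.parameters(), and so get updated by the optimizer, if the
# sub-module is an attribute of the model, or sits in an nn.ModuleList / nn.ModuleDict
# that is itself an attribute of the model.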
wrapped_linear = SublayerWrapper(nn.Linear(5, 5), 5, 0.1)
print(wrapped_linear)
print(wrapped_linear(torch.rand(5)).shape)

SublayerWrapper(
  (dropout): Dropout(p=0.1, inplace=False)
  (layernorm): LayerNorm((5,), eps=1e-05, elementwise_affine=True)
  (sub_layer): Linear(in_features=5, out_features=5, bias=True)
)
torch.Size([5])
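
The registration caveat in the comment above can be checked directly. The sketch below is only an illustration: `PlainListModel`, `ModuleListModel` and `count_params` are hypothetical names that do not appear in model.py. It compares a plain Python list of layers with an `nn.ModuleList` holding the same layers.

```python
import torch.nn as nn

class PlainListModel(nn.Module):
    """Hypothetical example: sub-modules kept in a plain Python list."""
    def __init__(self):
        super().__init__()
        self.layers = [nn.Linear(5, 5), nn.Linear(5, 5)]  # NOT registered

class ModuleListModel(nn.Module):
    """Hypothetical example: the same sub-modules kept in nn.ModuleList."""
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(5, 5), nn.Linear(5, 5)])  # registered

def count_params(model):
    # Total number of scalar parameters visible to an optimizer
    return sum(p.numel() for p in model.parameters())

plain_count = count_params(PlainListModel())
registered_count = count_params(ModuleListModel())
print(plain_count, registered_count)
```

The plain list exposes no parameters at all, so an optimizer built from `model.parameters()` would silently skip those layers.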


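### 1.2 FeedForwardNet (FeedForward in the figure)
The FFN (FeedForwardNet) consists of two fully connected layers with a ReLU and dropout in between. It processes each feature vector of length `d_model` independently, one time step of one batch element at a time, so the two layers are equivalent to two 1x1 convolutions.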

<img src="https://image.panwenbo.icu/blog20210728095705.png" alt="截屏2021-07-28 上午9.57.00" style="zoom:16%;" />

In [3]:
import torch.nn.functional as F

class FeedForwardNet(nn.Module):
    """A two-layer relu feedforward network following a multihead attention"""
    def __init__(self, d_model, dim_feedforward, dropout=0.1):
        super(FeedForwardNet, self).__init__()
        self.fc1 = nn.Linear(d_model, dim_feedforward)
        self.fc2 = nn.Linear(dim_feedforward, d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, input):
        return self.fc2(self.dropout(F.relu(self.fc1(input))))

In [4]:
# Quick test
LENGTH = 5
BATCH_SIZE = 128
D_MODEL = 512
ffn = FeedForwardNet(512, 2048)
print(ffn(torch.rand(LENGTH, BATCH_SIZE, D_MODEL)).shape)

torch.Size([5, 128, 512])
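
The claim that a position-wise linear layer equals a 1x1 convolution can be verified numerically. This is a minimal sketch with small, arbitrarily chosen sizes (`seq_len`, `batch`, `d_in`, `d_out` are illustrative, not values from the notebook): it copies the weights of an `nn.Linear` into an `nn.Conv1d` with `kernel_size=1` and compares the outputs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch, d_in, d_out = 5, 3, 8, 16

linear = nn.Linear(d_in, d_out)
conv = nn.Conv1d(d_in, d_out, kernel_size=1)

# Copy the linear weights into the convolution: (d_out, d_in) -> (d_out, d_in, 1)
with torch.no_grad():
    conv.weight.copy_(linear.weight.unsqueeze(-1))
    conv.bias.copy_(linear.bias)

x = torch.rand(seq_len, batch, d_in)
out_linear = linear(x)  # (seq_len, batch, d_out)

# Conv1d expects (batch, channels, length), so move the features to dim 1
# and permute back to (seq_len, batch, d_out) afterwards
out_conv = conv(x.permute(1, 2, 0)).permute(2, 0, 1)

print(torch.allclose(out_linear, out_conv, atol=1e-5))
```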


### 1.3 Multihead Attention
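To build multi-head attention, we first need single-head attention:
$$
\alpha_{ij}=\mathrm{SoftMax}_j\left(\frac{Q_iK_j^T}{\sqrt{d_k}}\right)\\\\
A_i=\textbf{Attention}(Q_i,(K,V))= \sum_{j=1}^{T} \alpha_{ij}V_j
$$
Here $T$ is the number of key/value positions, so the sum runs over the keys and the softmax is taken over $j$. This is the most intricate part of the whole Transformer. With the batched matrix multiplications used here, matching shapes alone do not guarantee a correct result; you also have to think about what each form of matrix multiplication actually computes. Note too that `mask` and `key_padding_mask` mean different things: `mask` specifies which keys are hidden at each time step (query position), while `key_padding_mask` marks where the padding tokens sit in each input sequence of the batch.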

<img src="https://image.panwenbo.icu/blog20210529204909.png" alt="截屏2021-05-29 下午8.49.02" style="zoom:10%;" />

In [5]:
import torch
import math

def attention(query, key, value, mask=None, key_padding_mask=None, dropout=None):
    """Compute scaled dot-product attention for the given Q, K, V

    Below, S is the query length, T is the key/value length, N is the batch size
    and E is the number of features. In self-attention S == T; in source
    attention S and T may differ. Key and value always have the same length.

    :param query: Q tensor :math:`(S, N, E)` or `(S, N_HEAD, N, E)`
    :param key: K tensor :math:`(T, N, E)` or `(T, N_HEAD, N, E)`
    :param value: V tensor :math:`(T, N, E)` or `(T, N_HEAD, N, E)`
    :param mask: Additive mask :math:`(S, T)` or `(N, N_HEAD, S, T)`. It is
    added onto the scores directly.
    :param key_padding_mask: BoolTensor marking the padding tokens. Positions
    that are True hold a padding token and will be masked.
    shape: :math:`(N, 1, 1, T)`
    :param dropout: dropout module applied to the attention weights, defaults to None
    :return: Attention values with shape :math:`(S, N, E)` or `(S, N_HEAD, N, E)`
    and attention weights with shape :math:`(S, N, T)` or `(S, N_HEAD, N, T)`
    """
    d_k = query.size(-1)

    # Batched matrix multiplication in torch:
    # the last two dimensions behave like an ordinary matrix
    # multiplication, while all leading dimensions are treated
    # as batch dimensions.
    # To apply the masks per batch element, first move the batch
    # dimension to the front. With q, k, v shaped (N, S/T, E),
    # multiplying (N, S, E) by (N, E, T) gives (N, S, T): for each of
    # the N batch elements, the scores of all T keys (last dim) for
    # each of the S queries (second dim).

    # 1) permute batch dim to the first place (S/T, N_HEAD, N, E) to (N, N_HEAD, S/T, E).
    query = query.transpose(0, -2)
    key = key.transpose(0, -2)
    value = value.transpose(0, -2)

    # 2) Use batched matrix multiplication to get a score tensor(N, N_HEAD, S, T)
    scores = query @ key.transpose(-2, -1) / math.sqrt(d_k)

    # 3) Apply the masks to the scores, then softmax
    # For key_padding_mask, a very large negative score makes the weight effectively zero.
    # As key_padding_mask has shape (N, 1, 1, T), masked_fill broadcasts
    # it to (N, N_HEAD, S, T) and every padded key is masked correctly.
    if key_padding_mask is not None:
        scores = scores.masked_fill(key_padding_mask, -1e9)

    # The additive mask only needs to be added onto the scores: it is
    # broadcast to (N, N_HEAD, S, T), so the same sequential mask applies
    # to every batch element.
    if mask is not None:
        scores += mask
    weights = F.softmax(scores, dim=-1)

    if dropout is not None:
        weights = dropout(weights)
    
    # 4) Compute the weighted values and transpose back
    # The product has shape (N, N_HEAD, S, E), so transpose() restores (S, N_HEAD, N, E).
    return (weights @ value).transpose(0, -2), weights.transpose(0, -2)
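
To make the formula concrete, here is a minimal single-head, single-sequence sketch of scaled dot-product attention with an additive causal mask. It is a simplified illustration, not the `attention` function above: there is no batch or head dimension, and it uses `-inf` rather than the `-1e9` used in this notebook. All variable names are chosen for the example.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_query, n_key, n_feat = 4, 4, 8
query = torch.rand(n_query, n_feat)
key = torch.rand(n_key, n_feat)
value = torch.rand(n_key, n_feat)

# Additive causal mask: 0 where attention is allowed, -inf above the diagonal
causal_mask = torch.triu(torch.full((n_query, n_key), float("-inf")), diagonal=1)

scores = query @ key.T / math.sqrt(n_feat) + causal_mask  # (n_query, n_key)
weights = F.softmax(scores, dim=-1)                       # each row sums to 1
attended = weights @ value                                # (n_query, n_feat)

print(weights)
```

The first query can only see the first key, so its output is exactly the first value vector; later queries mix progressively more values.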

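#### Multi-head Attention
Multi-head attention only requires reshaping the last dimension `d_model` into `(n_head, d_k)`. Passed through the `attention` function, the heads are then naturally treated as independent parts, and the result is reshaped back at the end.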

<img src="https://image.panwenbo.icu/blog20210603212610.png" alt="截屏2021-06-03 下午9.26.02" style="zoom:10%;" />

In [6]:
class MultiheadAttention(nn.Module):
    """MultiHead Attention module"""
    def __init__(self, d_model, nhead, dropout):
        """MultiHead Attention module"""
        super(MultiheadAttention, self).__init__()

        # All nhead features will be combined back to one feature with d_model dim
        assert d_model % nhead == 0
        self.d_k = d_model // nhead
        self.d_model = d_model
        self.nhead = nhead
        self.att_weights = None

        self.q_linear = nn.Linear(d_model, d_model, bias=False)
        self.v_linear = nn.Linear(d_model, d_model, bias=False)
        self.k_linear = nn.Linear(d_model, d_model, bias=False)
        self.out = nn.Linear(d_model, d_model, bias=False)
        self.dropout = dropout

    def forward(self, query, key, value, mask=None, key_padding_mask=None):
        """Forward propagation with multihead attention
        
        :param query: Q tensor :math:`(S, N, E)`
        :param key: K tensor :math:`(T, N, E)`
        :param value: V tensor :math:`(T, N, E)`
        :param mask: Additive attention mask :math:`(S, T)`, added onto the
        attention scores, defaults to None
        :param key_padding_mask: BoolTensor :math:`(N, T)`, True marks padding
        tokens in the keys, defaults to None
        """
        if mask is not None:
            # Expand the mask from (S, T) to (1, 1, S, T) so that it broadcasts
            # to (N, N_HEAD, S, T) and applies to all heads equally
            mask = mask.unsqueeze(0).unsqueeze(0)
        
        if key_padding_mask is not None:
            # Expand the padding mask from (N, T) to (N, 1, 1, T) so that each
            # batch element masks its own padding tokens.
            key_padding_mask = key_padding_mask.unsqueeze(1).unsqueeze(1)
        n_qlen, n_batch, d_model = query.shape
        n_klen = key.size(0)

        # Compute nhead features from q, k and v.
        # Transpose them into shape (L, N_HEAD, N, K) for convenience of parallelization
        query = self.q_linear(query).view(n_qlen, n_batch, self.nhead, self.d_k).transpose(1, 2)
        key = self.k_linear(key).view(n_klen, n_batch, self.nhead, self.d_k).transpose(1, 2)
        value = self.v_linear(value).view(n_klen, n_batch, self.nhead, self.d_k).transpose(1, 2)

        # Do attention calculation and concatenate
        att_value, self.att_weights = attention(query, key, value, mask, key_padding_mask, self.dropout)
        att_value = att_value.transpose(1, 2).contiguous().view(n_qlen, n_batch, d_model)

        return self.out(att_value)
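
The `view`/`transpose` bookkeeping in `forward` is easy to get wrong, so it may help to see it on a tiny tensor. The sketch below uses small illustrative sizes (not the notebook's) and shows that each head receives a contiguous slice of the feature dimension and that the merge step is the exact inverse of the split.

```python
import torch

seq_len, batch, n_head, d_k = 5, 2, 4, 3
d_model = n_head * d_k

x = torch.arange(seq_len * batch * d_model, dtype=torch.float32)
x = x.view(seq_len, batch, d_model)

# Split: (L, N, d_model) -> (L, N, n_head, d_k) -> (L, n_head, N, d_k)
heads = x.view(seq_len, batch, n_head, d_k).transpose(1, 2)

# Head h sees only features [h * d_k, (h + 1) * d_k) of every vector
head_1_matches = torch.equal(heads[:, 1], x[:, :, d_k:2 * d_k])

# Merge: undo the transpose, then flatten the last two dims back to d_model.
# contiguous() is needed because view() cannot operate on the transposed layout.
merged = heads.transpose(1, 2).contiguous().view(seq_len, batch, d_model)

print(heads.shape, head_1_matches, torch.equal(merged, x))
```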

In [7]:
# Quick test
NHEAD = 8
ma = MultiheadAttention(D_MODEL, NHEAD, nn.Dropout(0.1))
q = k = v = torch.rand(LENGTH, BATCH_SIZE, D_MODEL)
mask = (1 - torch.triu(torch.ones(LENGTH, LENGTH)).T) * -1e9
key_padding_mask = torch.zeros((BATCH_SIZE, LENGTH))
key_padding_mask[4, 2:] = 1 # mask the last 3 keys of the 5th sample
key_padding_mask = key_padding_mask == 1

attn = ma(q, k, v, mask, key_padding_mask)
print("Attn shape: ", attn.shape)
print("Weight shape: ", ma.att_weights.shape)
print("Mask: \n", mask == -1e9)
print("First batch: \n", ma.att_weights[:, 0, 0, :])
print("Fifth batch: \n", ma.att_weights[:, 0, 4, :])

Attn shape:  torch.Size([5, 128, 512])
Weight shape:  torch.Size([5, 8, 128, 5])
Mask: 
 tensor([[False,  True,  True,  True,  True],
        [False, False,  True,  True,  True],
        [False, False, False,  True,  True],
        [False, False, False, False,  True],
        [False, False, False, False, False]])
First batch: 
 tensor([[1.1111, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5837, 0.5275, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3552, 0.3709, 0.0000, 0.0000],
        [0.2930, 0.2671, 0.2889, 0.0000, 0.0000],
        [0.2273, 0.2204, 0.2247, 0.2184, 0.2203]], grad_fn=<SliceBackward>)
Fifth batch: 
 tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5645, 0.5466, 0.0000, 0.0000, 0.0000],
        [0.5623, 0.5488, 0.0000, 0.0000, 0.0000],
        [0.5512, 0.5599, 0.0000, 0.0000, 0.0000],
        [0.5705, 0.5406, 0.0000, 0.0000, 0.0000]], grad_fn=<SliceBackward>)
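
The weights printed above include values such as 1.1111 and rows that do not sum to 1. That is consistent with the attention dropout being active: in training mode `nn.Dropout(0.1)` zeroes entries at random and rescales the survivors by 1 / (1 - 0.1) ≈ 1.1111. The following sketch, which is independent of the classes above, shows both modes on a small weight matrix.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
dropout = nn.Dropout(0.1)
weights = F.softmax(torch.rand(4, 6), dim=-1)  # strictly positive, rows sum to 1

dropout.train()  # training mode: dropout is active
train_out = dropout(weights)
# Every surviving entry is the original value divided by (1 - p)
kept = train_out != 0
rescaled_ok = torch.allclose(train_out[kept], weights[kept] / 0.9)

dropout.eval()   # evaluation mode: dropout is the identity
eval_out = dropout(weights)

print(rescaled_ok, torch.equal(eval_out, weights))
```

Calling `.eval()` on the model before inspecting attention weights should therefore give rows that sum to 1 over the visible keys.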


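## 2. Assembling the Transformer
We first combine the sublayers into an encoder layer (two sublayers) and a decoder layer (three sublayers), then stack several copies of each layer to form the Encoder and the Decoder, and finally join the two into what PyTorch calls a Transformer. We say "what PyTorch calls" because a complete Transformer also includes embedding the tokens into vectors and the final softmax output, whereas PyTorch's Transformer covers only the Encoder and the Decoder.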

### 2.1 TransformerEncoderLayer (left side of the figure)
Wrapping the multi-head attention sublayer and the FFN in Add & Norm and chaining them gives the Transformer's EncoderLayer.

In [8]:
class TransformerEncoderLayer(nn.Module):
    """One encoder layer in the transformer"""
    def __init__(self, 
                 d_model, 
                 nhead, 
                 dim_feedforward=2048, 
                 dropout=0.1):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = SublayerWrapper(
            MultiheadAttention(d_model, nhead, nn.Dropout(dropout)),
            d_model=d_model,
            dropout_r=dropout)

        self.ffn = SublayerWrapper(
            FeedForwardNet(d_model, dim_feedforward, dropout),
            d_model=d_model,
            dropout_r=dropout)
    
    # src_mask is added directly onto the attention scores.
    # src_key_padding_mask is a BoolTensor; positions that are True will be masked.
    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        attn = self.self_attn(src, src, src, 
                              mask=src_mask, 
                              key_padding_mask=src_key_padding_mask)
        return self.ffn(attn)


### 2.2 TransformerEncoder
Chaining several TransformerEncoderLayers end to end gives the TransformerEncoder. Note that layers must be duplicated with `copy.deepcopy()`, and they must be stored in an `nn.ModuleList` so that torch can see them.

In [9]:
import copy

class TransformerEncoder(nn.Module):
    def __init__(self, encoder_layer, num_layers, norm=None):
        super(TransformerEncoder, self).__init__()
        self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(num_layers)])
        self.norm = norm
    
    def forward(self, src, mask=None, src_key_padding_mask=None):
        for layer in self.layers:
            src = layer(src, mask, src_key_padding_mask)
        
        if self.norm is not None:
            src = self.norm(src)
        return src
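
Why `copy.deepcopy()` matters can be seen by comparing it with reusing the same layer object. In this sketch a plain `nn.Linear` stands in for an encoder layer, and `count_params` is a helper written for the example.

```python
import copy
import torch
import torch.nn as nn

layer = nn.Linear(4, 4)

# The same object three times: all "layers" share one set of weights
shared = nn.ModuleList([layer for _ in range(3)])
# Three independent copies, as TransformerEncoder does
cloned = nn.ModuleList([copy.deepcopy(layer) for _ in range(3)])

def count_params(module):
    # parameters() de-duplicates shared tensors
    return sum(p.numel() for p in module.parameters())

print(count_params(shared), count_params(cloned))

# Clones start with identical values but live in separate storage,
# so they diverge as soon as training updates them
same_values = torch.equal(cloned[0].weight, cloned[1].weight)
same_storage = cloned[0].weight.data_ptr() == cloned[1].weight.data_ptr()
print(same_values, same_storage)
```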

In [10]:
# Quick test
encoder_layer = TransformerEncoderLayer(D_MODEL, NHEAD)
encoder = TransformerEncoder(encoder_layer, num_layers=2)
src = torch.rand(LENGTH, BATCH_SIZE, D_MODEL)
print("Memory shape: ", encoder(src).shape)
print(encoder)

Memory shape:  torch.Size([5, 128, 512])
TransformerEncoder(
  (layers): ModuleList(
    (0): TransformerEncoderLayer(
      (self_attn): SublayerWrapper(
        (dropout): Dropout(p=0.1, inplace=False)
        (layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (sub_layer): MultiheadAttention(
          (q_linear): Linear(in_features=512, out_features=512, bias=False)
          (v_linear): Linear(in_features=512, out_features=512, bias=False)
          (k_linear): Linear(in_features=512, out_features=512, bias=False)
          (out): Linear(in_features=512, out_features=512, bias=False)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (ffn): SublayerWrapper(
        (dropout): Dropout(p=0.1, inplace=False)
        (layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (sub_layer): FeedForwardNet(
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_feat

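### 2.3 TransformerDecoderLayer (right side of the figure)
A decoder layer has three sublayers: a self-attention layer; a source-attention layer, which takes Q from the output of the preceding self-attention layer but takes K and V from the output of the Encoder's last layer; and finally the FFN. The DecoderLayer's forward pass therefore also needs the output of the Encoder's last layer as `memory`.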

In [11]:
class TransformerDecoderLayer(nn.Module):
    def __init__(self, d_model, nhead, 
                 dim_feedforward=2048, 
                 dropout=0.1):
        super(TransformerDecoderLayer, self).__init__()
        self.src_attn = SublayerWrapper(
            MultiheadAttention(d_model, nhead, nn.Dropout(dropout)),
            d_model=d_model,
            dropout_r=dropout)

        self.self_attn = SublayerWrapper(
            MultiheadAttention(d_model, nhead, nn.Dropout(dropout)),
            d_model=d_model,
            dropout_r=dropout)

        self.ffn = SublayerWrapper(
            FeedForwardNet(d_model, dim_feedforward, dropout),
            d_model=d_model,
            dropout_r=dropout)
    
    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, 
                tgt_key_padding_mask=None, 
                memory_key_padding_mask=None):
        tgt = self.self_attn(tgt, tgt, tgt, 
                             mask=tgt_mask, 
                             key_padding_mask=tgt_key_padding_mask)

        tgt = self.src_attn(tgt, memory, memory, 
                            mask=memory_mask, 
                            key_padding_mask=memory_key_padding_mask)

        return self.ffn(tgt)
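
The shape consequences of source attention are worth spelling out: queries come from the target side and keys/values from the encoder memory, so the weight matrix is rectangular. The sketch below is a stripped-down illustration with no learned projections, no heads, and a batch-first layout chosen for readability; names and sizes are illustrative.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
tgt_len, src_len, batch, d_model = 4, 6, 2, 8

decoder_state = torch.rand(batch, tgt_len, d_model)   # source of Q
encoder_memory = torch.rand(batch, src_len, d_model)  # source of K and V

# (batch, tgt_len, d_model) @ (batch, d_model, src_len) -> (batch, tgt_len, src_len)
scores = decoder_state @ encoder_memory.transpose(-2, -1) / math.sqrt(d_model)
weights = F.softmax(scores, dim=-1)   # each target position spreads over the source
context = weights @ encoder_memory    # (batch, tgt_len, d_model)

print(weights.shape, context.shape)
```

The output keeps the target length, which is why the residual connection in `SublayerWrapper` can add it back onto `tgt`.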

### 2.4 TransformerDecoder
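The Decoder has essentially the same overall structure as the Encoder. The one thing to note is that both the memory and the target sequence take part in attention, and `memory_mask` and `tgt_mask` are the same kind of mask: each is added directly onto the attention scores.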

In [12]:
class TransformerDecoder(nn.Module):
    def __init__(self, decoder_layer, num_layers, norm=None):
        super(TransformerDecoder, self).__init__()
        self.layers = nn.ModuleList([copy.deepcopy(decoder_layer) for _ in range(num_layers)])
        self.norm = norm
    
    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, 
                tgt_key_padding_mask=None, 
                memory_key_padding_mask=None):
        for layer in self.layers:
            tgt = layer(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
                        tgt_key_padding_mask=tgt_key_padding_mask,
                        memory_key_padding_mask=memory_key_padding_mask)
            
        # Apply the optional final norm only after all layers have run
        if self.norm is not None:
            tgt = self.norm(tgt)
        return tgt


In [13]:
# Quick test
TGT_LEN = 4
SRC_LEN = 5

decoder_layer = TransformerDecoderLayer(D_MODEL, NHEAD)
decoder = TransformerDecoder(decoder_layer, num_layers=2)
tgt = torch.rand(TGT_LEN, BATCH_SIZE, D_MODEL)
src = torch.rand(SRC_LEN, BATCH_SIZE, D_MODEL)
memory = encoder(src)

print("Output shape: ", decoder(tgt, memory).shape)
print(decoder)

### 2.5 Transformer
We can connect the Encoder and the Decoder to obtain the final Transformer.

In [None]:
class Transformer(nn.Module):
    """Transformer Module"""
    def __init__(self, d_model=512, nhead=8, 
                 num_encoder_layers=6, num_decoder_layers=6, 
                 dim_feedforward=2048, dropout=0.1):
        super(Transformer, self).__init__()
        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout)
        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout)
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers)
    
    def forward(self, src, tgt, src_mask=None, tgt_mask=None, 
                memory_mask=None, src_key_padding_mask=None, 
                tgt_key_padding_mask=None, memory_key_padding_mask=None):
        memory = self.encoder(
            src, 
            mask=src_mask, 
            src_key_padding_mask=src_key_padding_mask
            )
        output = self.decoder(
            tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask
            )
        return output

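## 3. Handling Inputs and Outputs
We now have the Transformer structure, but it still only accepts `d_model`-dimensional inputs and produces `d_model`-dimensional outputs, so we also need a word-embedding module and a module that predicts the output.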

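### 3.1 Positional Encoding
Besides embedding the input tokens, we also add a position vector of the same size to each word vector; the sum is what finally enters the Encoder. The position vector is defined as: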
$$
P_{(pos,2i)}=\sin\left(\frac{pos}{10000^{2i/d}}\right)\\\\
P_{(pos,2i+1)}=\cos\left(\frac{pos}{10000^{2i/d}}\right)
$$
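Here `pos` is the position of the word in the input sequence, and `i` indexes the sine/cosine pairs within the embedding vector, so `2i` and `2i+1` are the feature positions. In the implementation below, `den` is $\frac{1}{10000^{2i/d}}$. With this, the Transformer can tell apart words at different positions.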
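
As a sanity check on the formula, the sketch below computes `den` in two ways, once with the `exp`/`log` form and once directly as $1/10000^{2i/d}$, and then fills a small encoding table. The sizes `d = 8` and `max_len = 50` are arbitrary choices for illustration.

```python
import math
import torch

d, max_len = 8, 50

# Vectorised form: exp(-2i * ln(10000) / d)
den = torch.exp(torch.arange(0, d, 2) * (-math.log(10000)) / d)

# Direct form of the formula: 1 / 10000^(2i / d)
direct = torch.tensor([1.0 / (10000 ** (two_i / d)) for two_i in range(0, d, 2)])

pos = torch.arange(0, max_len).reshape(max_len, 1)
pe = torch.zeros(max_len, d)
pe[:, 0::2] = torch.sin(den * pos)  # even feature indices
pe[:, 1::2] = torch.cos(den * pos)  # odd feature indices

print(torch.allclose(den, direct))
print(pe[0])  # position 0: sin(0) = 0 and cos(0) = 1 alternate
```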

In [None]:
class PositionalEncoding(nn.Module):
    """Add positional encoding to token embedding vector"""
    def __init__(self, embedding_dim, dropout_r, max_len=5000):
        """Add positional encoding to token embedding vector

        :param embedding_dim: The embedding dim of positional encoding
        :param dropout_r: Ratio of combined embedding vector dropout
        :param max_len: Max length for which positional encodings are generated, defaults to 5000
        """
        super(PositionalEncoding, self).__init__()
        den = torch.exp(torch.arange(0, embedding_dim, 2) * (-math.log(10000)) / embedding_dim)
        pos = torch.arange(0, max_len).reshape(max_len, 1)
        pos_embedding = torch.zeros((max_len, embedding_dim))
        pos_embedding[:, 0::2] = torch.sin(den * pos)
        pos_embedding[:, 1::2] = torch.cos(den * pos)
        pos_embedding = pos_embedding.unsqueeze(-2) # Reshape into [max_len, 1, emb_size]

        self.dropout = nn.Dropout(dropout_r)

        # Register a tensor as a buffer when
        # 1) it is not part of the compute graph and needs no gradient
        # 2) it should be saved together with the model's state_dict()
        # 3) model.to(device) should move it along with the parameters
        self.register_buffer('pos_embedding', pos_embedding)
    
    def forward(self, token_embedding):
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])

In [None]:
class TokenEmbedding(nn.Module):
    """Word embedding but scaled"""
    def __init__(self, vocab_size, embedding_dim):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.embedding_dim = embedding_dim
    
    def forward(self, tokens):
        return self.embedding(tokens) * math.sqrt(self.embedding_dim)

In [None]:
# Quick test
VOCAB_SIZE = 10

seq = torch.ones((LENGTH, BATCH_SIZE)).long()
embedding = TokenEmbedding(VOCAB_SIZE, D_MODEL)
pos_embedding = PositionalEncoding(D_MODEL, 0.1)

embedded = pos_embedding(embedding(seq))
print("Embedded shape: ", embedded.shape)

Embedded shape:  torch.Size([5, 128, 512])


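### 3.2 SoftMax Output
Last of all, we need a simple SoftMax output layer. The implementation uses `nn.LogSoftmax`, so it returns log-probabilities.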

In [None]:
class Generator(nn.Module):
    def __init__(self, d_model, vocab_size):
        super(Generator, self).__init__()
        self.out = nn.Linear(d_model, vocab_size)
        self.softmax = nn.LogSoftmax(-1)
    
    def forward(self, input):
        return self.softmax(self.out(input))
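
Because `Generator` ends in `nn.LogSoftmax`, its output pairs naturally with `nn.NLLLoss`. The sketch below, using random stand-in tensors rather than real model output, checks that this combination gives the same loss as `nn.CrossEntropyLoss` applied to the raw logits.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch, vocab = 4, 3, 10

logits = torch.randn(seq_len, batch, vocab)          # stand-in for the linear layer's output
targets = torch.randint(0, vocab, (seq_len, batch))  # stand-in for gold token ids

log_probs = nn.LogSoftmax(dim=-1)(logits)            # what Generator returns
nll = nn.NLLLoss()(log_probs.view(-1, vocab), targets.view(-1))
ce = nn.CrossEntropyLoss()(logits.view(-1, vocab), targets.view(-1))

print(torch.allclose(nll, ce))
```

Feeding the `Generator` output into `nn.CrossEntropyLoss` would apply the log-softmax a second time, so `nn.NLLLoss` is the matching criterion here.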

## 4. Putting It All Together!
With every component in hand, summoning Exodia (just kidding) becomes much simpler. We call the final Transformer class `Seq2seqTransformer`.

<img src="https://image.panwenbo.icu/blog20210728130636.jpg" alt="5a6fa4717e79dac2b2cb8ec4e5246423" style="zoom:75%;" />

In [None]:
class Seq2seqTransformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size,
                 num_encoder_layers=6, num_decoder_layers=6,
                 d_model=512, num_heads=8, dim_feedforward=2048,
                 dropout=0.1):
        super(Seq2seqTransformer, self).__init__()
        self.src_embedding = TokenEmbedding(src_vocab_size, d_model)
        self.tgt_embedding = TokenEmbedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, dropout)
        
        self.transformer = Transformer(
            d_model=d_model, nhead=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout)
        
        self.generator = Generator(d_model, tgt_vocab_size)
    
    def forward(self, src, tgt, src_mask, tgt_mask,
                src_padding_mask, tgt_padding_mask):
        src_embedded = self.positional_encoding(self.src_embedding(src))
        tgt_embedded = self.positional_encoding(self.tgt_embedding(tgt))

        output = self.transformer(
            src=src_embedded, tgt=tgt_embedded,
            src_mask=src_mask, tgt_mask=tgt_mask,
            memory_mask=None,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=src_padding_mask)
        
        return self.generator(output)

In [None]:
# Quick test
transformer = Seq2seqTransformer(10, 10)
src = torch.randint(0, 10, (SRC_LEN, BATCH_SIZE)).long()
tgt = torch.randint(0, 10, (TGT_LEN, BATCH_SIZE)).long()
src_mask = torch.zeros((SRC_LEN, SRC_LEN))
tgt_mask = (1 - torch.triu(torch.ones(TGT_LEN, TGT_LEN)).T) * -1e9
src_padding_mask = torch.zeros((BATCH_SIZE, SRC_LEN))
src_padding_mask[4, 2:] = 1 # mask the last 3 keys of the 5th sample
src_padding_mask = src_padding_mask == 1
# masked_fill expects a boolean mask, so build the empty padding mask as bool
tgt_padding_mask = torch.zeros((BATCH_SIZE, TGT_LEN), dtype=torch.bool)

output = transformer(src, tgt, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask)
print("Output shape: ", output.shape)

Output shape:  torch.Size([4, 128, 10])
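
In practice the four masks are derived from the token ids rather than written by hand. The helper below is a hedged sketch of one way to do that: `create_masks` and `PAD_IDX` are hypothetical names introduced for this example, and the padding id 0 is an assumption that depends on your vocabulary. It follows the conventions used in this notebook: additive float masks with `-1e9`, and boolean padding masks shaped (batch, length).

```python
import torch

PAD_IDX = 0  # hypothetical padding id; use whatever your vocabulary defines

def create_masks(src_tokens, tgt_tokens, pad_idx=PAD_IDX):
    """Build the four mask arguments for token tensors shaped (length, batch)."""
    src_len, tgt_len = src_tokens.size(0), tgt_tokens.size(0)
    # The encoder may look at every source position: an all-zero additive mask
    src_mask = torch.zeros(src_len, src_len)
    # The decoder must not look ahead: -1e9 above the diagonal, 0 elsewhere
    tgt_mask = torch.triu(torch.full((tgt_len, tgt_len), -1e9), diagonal=1)
    # Boolean padding masks shaped (batch, length): True marks a padding token
    src_padding_mask = (src_tokens == pad_idx).transpose(0, 1)
    tgt_padding_mask = (tgt_tokens == pad_idx).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

# Two toy sequences per side, stored column-wise and padded with PAD_IDX
src_tokens = torch.tensor([[5, 7], [6, 8], [9, 0], [0, 0]])  # (4, 2)
tgt_tokens = torch.tensor([[1, 1], [4, 3], [2, 0]])          # (3, 2)
masks = create_masks(src_tokens, tgt_tokens)
print([m.shape for m in masks])
```

The returned tuple is ordered to match the `src_mask, tgt_mask, src_padding_mask, tgt_padding_mask` arguments of `Seq2seqTransformer.forward`.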
