In [2]:
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

reference:

http://nlp.seas.harvard.edu/annotated-transformer/


### Transformer整体结构

Transformer整体上是encoder-decoder架构。

Encoder把**符号表示序列**映射为**连续表示序列**， $\left(x_{1}, \ldots, x_{n}\right) \to \mathbf{z}=\left(z_{1}, \ldots, z_{n}\right)$

Decoder在给定$\mathbf{z}$的条件下，逐个元素地生成输出序列$\left(y_{1}, \ldots, y_{m}\right)$。
在生成下一个符号时，使用先前生成的符号作为额外的输入。

下面代码段展示了encoder-decoder的框架：

In [6]:
class EncoderDecoder(nn.Module):
    """Standard encoder-decoder wrapper.

    Args:
        encoder: callable invoked as ``encoder(src_embed(src), src_mask)``.
        decoder: callable invoked as
            ``decoder(tgt_embed(tgt), memory, src_mask, tgt_mask)``.
        src_embed: callable applied to ``src`` before it reaches the encoder.
        tgt_embed: callable applied to ``tgt`` before it reaches the decoder.
        generator: stored on the instance; it is not called by ``forward``,
            ``encode`` or ``decode``.
    """

    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        # The constructor has to be spelled ``__init__``: under any other name
        # Python never runs it, and none of the attributes below would exist.
        super(EncoderDecoder, self).__init__()

        self.encoder = encoder
        self.decoder = decoder

        self.src_embed = src_embed
        self.tgt_embed = tgt_embed

        self.generator = generator

    def forward(self, src, tgt, src_mask, tgt_mask):
        """Encode ``src`` and decode ``tgt`` against the encoder output.

        Returns:
            Whatever ``self.decoder`` returns; ``self.generator`` is not applied.
        """
        # The encoder output is what the decoder receives as ``memory``.
        memory = self.encode(src, src_mask)
        return self.decode(memory, src_mask, tgt, tgt_mask)

    def encode(self, src, src_mask):
        """Embed ``src`` and run it through the encoder with ``src_mask``."""
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        """Embed ``tgt`` and run the decoder over it and ``memory``."""
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
        

In [8]:
class Generator(nn.Module):
    """Linear projection followed by a log-softmax over the last dimension.

    Args:
        d_model: input feature size of the projection.
        vocab: output feature size of the projection.
    """

    def __init__(self, d_model, vocab):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab)

    def forward(self, x):
        """Project ``x`` and return log-probabilities along ``dim=-1``."""
        logits = self.proj(x)
        return F.log_softmax(logits, dim=-1)
        

Transformer也采用如上encoder-decoder结构。
在encoder和decoder中分别使用了堆叠的self-attention和position-wise全连接层。

![image](./images/169628874-e9586707-02cc-439b-a0a2-7b5202d16c38.png)

In [3]:
def clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` deep copies of ``module``.

    Each entry is produced by ``copy.deepcopy``, so the copies start with the
    same values but do not share parameters with each other or with ``module``.
    """
    replicas = nn.ModuleList()
    for _ in range(N):
        replicas.append(copy.deepcopy(module))
    return replicas

关于ModuleList可以参考： https://zhuanlan.zhihu.com/p/64990232

TODO:

- [ ] 可以学习一下copy和deepcopy的区别


In [4]:
class Encoder(nn.Module):
    """A stack of ``N`` copies of ``layer`` followed by one final ``LayerNorm``.

    Args:
        layer: module to replicate with ``clones``; it must expose ``.size``,
            which sets the width of the final norm.
        N: number of copies to stack.
    """

    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask):
        """Feed ``x`` through every layer in order (each layer also receives
        ``mask``), then apply the final norm."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask)
        return self.norm(hidden)

对应上面的框架图，Encoder中的每个layer包括：MultiHead Attention，残差连接+LayerNorm， Feed Forward Network，残差连接+LayerNorm。

N层结束后，最后还有一个LayerNorm。

每个layer中的残差连接+LayerNorm在论文中表示为：LayerNorm(x + Sublayer(x))。注意：下面的代码实现把LayerNorm放在子层之前，即 x + Dropout(Sublayer(LayerNorm(x)))，与论文公式的顺序不同。

先看LayerNorm:

In [5]:
class LayerNorm(nn.Module):
    """Normalization over the last dimension with a learnable scale and shift.

    Args:
        features: length of the learnable scale (``a_2``, initialised to ones)
            and shift (``b_2``, initialised to zeros) vectors.
        eps: constant added to the standard deviation before dividing.
    """

    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        """Return ``a_2 * (x - mean) / (std + eps) + b_2`` along ``dim=-1``."""
        mu = x.mean(-1, keepdim=True)
        sigma = x.std(-1, keepdim=True)
        # eps is added to the standard deviation itself (not to the variance).
        scaled = self.a_2 * (x - mu)
        return scaled / (sigma + self.eps) + self.b_2

再看下残差连接：

In [6]:
class SublayerConnection(nn.Module):
    """Residual wrapper: ``x + dropout(sublayer(norm(x)))``.

    Args:
        size: feature size handed to ``LayerNorm``.
        dropout: probability handed to ``nn.Dropout``.
    """

    def __init__(self, size, dropout):
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        """Normalize ``x``, pass it to ``sublayer`` (called with that single
        argument), apply dropout, and add the result back onto ``x``."""
        normed = self.norm(x)
        branch = self.dropout(sublayer(normed))
        return x + branch

再回头看框架图，encoder中每层（共N层）中有两处sublayer，一个是MultiHeadAttention，一个是position-wise全连接层。
下面我们就按照encoder框架图实现EncoderLayer。

这里再放一下框架图，省得翻回去看。

![image](./images/169628874-e9586707-02cc-439b-a0a2-7b5202d16c38.png)

In [None]:
class EncoderLayer(nn.Module):
    """One encoder layer: self-attention, then feed-forward, each wrapped in
    its own ``SublayerConnection``.

    Args:
        size: feature size passed to each ``SublayerConnection``; also kept as
            ``self.size``.
        self_attn: module called as ``self_attn(x, x, x, mask)``.
        feed_forward: module called with a single tensor argument.
        dropout: dropout probability passed to each ``SublayerConnection``.
    """

    def __init__(self, size, self_attn, feed_forward, dropout):
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size

    def forward(self, x, mask):
        """Apply the attention sublayer and then the feed-forward sublayer."""

        # The wrapper calls its sublayer with ONE argument, but ``self_attn``
        # takes (query, key, value, mask). This closure adapts the signature
        # and carries ``mask`` along.
        def attend(normed):
            return self.self_attn(normed, normed, normed, mask)

        attended = self.sublayer[0](x, attend)
        return self.sublayer[1](attended, self.feed_forward)

### 下面看Decoder

有了Encoder基础，Decoder就会容易上手一些。不过还是值得注意二者输入输出的差别。

In [7]:
class Decoder(nn.Module):
    """A stack of ``N`` copies of ``layer`` followed by one final ``LayerNorm``.

    Args:
        layer: module to replicate with ``clones``; it must expose ``.size``,
            which sets the width of the final norm.
        N: number of copies to stack.
    """

    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, memory, src_mask, tgt_mask):
        """Feed ``x`` through every layer in order, then apply the final norm.

        ``memory``, ``src_mask`` and ``tgt_mask`` are passed unchanged to each
        layer.
        """
        hidden = x
        for block in self.layers:
            hidden = block(hidden, memory, src_mask, tgt_mask)
        return self.norm(hidden)

区别于encoder, decoder有三个sublayer。

In [8]:
class DecoderLayer(nn.Module):
    """One decoder layer: self-attention, attention over ``memory``, then
    feed-forward, each wrapped in its own ``SublayerConnection``.

    Args:
        size: feature size passed to each ``SublayerConnection``; also kept as
            ``self.size``.
        self_attn: module called as ``self_attn(x, x, x, tgt_mask)``.
        src_attn: module called as ``src_attn(x, memory, memory, src_mask)``.
        feed_forward: module called with a single tensor argument.
        dropout: dropout probability passed to each ``SublayerConnection``.
    """

    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, memory, src_mask, tgt_mask):
        """Run the three sublayers in order and return the final tensor."""

        # Each wrapper calls its sublayer with a single tensor, so the two
        # attention calls are adapted through closures that carry the masks
        # (and ``memory``) along.
        def attend_self(normed):
            return self.self_attn(normed, normed, normed, tgt_mask)

        def attend_memory(normed):
            # Query comes from the decoder stream; key and value are ``memory``.
            return self.src_attn(normed, memory, memory, src_mask)

        hidden = self.sublayer[0](x, attend_self)
        hidden = self.sublayer[1](hidden, attend_memory)
        return self.sublayer[2](hidden, self.feed_forward)