# Implementing a Transformer from Scratch
1. Implement the following components
   1. Embedding
   2. Positional Encoding
   3. Positionwise Feed-Forward Network
   4. Add & Norm
   5. (Masked) Multi-head attention
2. Assemble the components
   1. Encoder
   2. Decoder
   3. Bahdanau attention
3. Validate on a small dataset

In [None]:
import torch
from torch import nn
import math


## Embedding
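PyTorch provides this as `nn.Embedding`. Its two main parameters are `num_embeddings` and `embedding_dim`:
1. `num_embeddings`: the number of distinct tokens after the text is tokenized (the vocabulary size), including special tokens.
2. `embedding_dim`: the dimension of the vector each token is embedded into.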
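A minimal sketch of `nn.Embedding` in use. The vocabulary size 10, dimension 4 and padding id 0 are arbitrary illustrative values. With `padding_idx` set, the row for the padding token starts as all zeros and receives no gradient, so padded positions carry no learned signal.

```python
import torch
from torch import nn

torch.manual_seed(0)
PAD = 0                                   # illustrative padding token id
emb = nn.Embedding(num_embeddings=10, embedding_dim=4, padding_idx=PAD)

# a batch of 2 token-id sequences, right-padded with PAD to length 5
tokens = torch.tensor([[5, 2, 7, PAD, PAD],
                       [3, 9, 1, 4, PAD]])
vectors = emb(tokens)
print(vectors.shape)          # torch.Size([2, 5, 4]): (batch, seq_len, embedding_dim)
print(vectors[0, 3])          # the padding row is all zeros

# the padding row gets no gradient during backpropagation
vectors.sum().backward()
print(emb.weight.grad[PAD])   # tensor([0., 0., 0., 0.])
```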

## Positional Encoding
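After a text sequence is tokenized and embedded, we get a representation matrix $\mathbf{X} \in \mathbb{R}^{n\times d}$. The Transformer processes all positions of the sequence in parallel, with no recurrence, so position information has to be injected explicitly: $\mathbf{X} + \mathbf{P}$. The positional encoding should satisfy the following requirements:
1. Each position in the text has a unique encoding.
2. The distance between any two positions is consistent, i.e. it depends only on how far apart the two words are.
3. The encoding is not affected by the length of the text.
4. The encoding is deterministic.

The commonly used (sinusoidal) positional encoding is: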
$$
\begin{aligned}
p_{i, 2 j} & =\sin \left(\frac{i}{10000^{2 j / d}}\right) \\
p_{i, 2 j+1} & =\cos \left(\frac{i}{10000^{2 j / d}}\right)
\end{aligned}
$$
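where $p_{i,2j}$ and $p_{i,2j+1}$ are elements of the matrix $\mathbf{P} \in \mathbb{R}^{n\times d}$: $i$ is the position in the sequence, and $2j$ and $2j+1$ are the even and odd column (feature) indices.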
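To make requirement 2 concrete, here is a small sketch (the function name `sinusoidal_pe` and the sizes are illustrative) that builds $\mathbf{P}$ from the formula above and checks a known property of it: for a fixed offset $k$, each $(\sin, \cos)$ column pair at position $i+k$ is the pair at position $i$ multiplied by a $2\times 2$ rotation matrix that depends only on $k$ and $j$, not on $i$.

```python
import math
import torch

def sinusoidal_pe(n, d):
    """P in R^{n x d} from the sin/cos formula above (d assumed even)."""
    i = torch.arange(n, dtype=torch.float64).reshape(-1, 1)                      # positions, (n, 1)
    denom = torch.pow(10000.0, torch.arange(0, d, 2, dtype=torch.float64) / d)   # 10000^{2j/d}, (d/2,)
    P = torch.zeros(n, d, dtype=torch.float64)
    P[:, 0::2] = torch.sin(i / denom)   # even columns 2j
    P[:, 1::2] = torch.cos(i / denom)   # odd columns 2j+1
    return P

n, d, k = 50, 8, 3
P = sinusoidal_pe(n, d)
print(P[0])   # position 0: sin(0)=0 in even columns, cos(0)=1 in odd columns

# Fixed-offset property: [sin((i+k)w), cos((i+k)w)] = [sin(iw), cos(iw)] @ M(k, w)
for j in range(d // 2):
    w = 1.0 / 10000 ** (2 * j / d)
    c, s = math.cos(k * w), math.sin(k * w)
    M = torch.tensor([[c, -s], [s, c]], dtype=torch.float64)
    pairs = P[:, 2 * j:2 * j + 2]                 # (n, 2): columns 2j, 2j+1
    assert torch.allclose(pairs[:-k] @ M, pairs[k:])
print("offset property holds")
```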

In [None]:
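class PositionalEncoding(nn.Module):
    def __init__(self, hidden_size, dropout_p, seq_max_len=1000):
        super().__init__()
        self.dropout = nn.Dropout(dropout_p)
        posi_enc = torch.zeros((1, seq_max_len, hidden_size))
        # positions i = 0 .. seq_max_len-1, shape (seq_max_len, 1)
        pos = torch.arange(0, seq_max_len, dtype=torch.float32).reshape(-1, 1)
        # 10000^{2j/d}, shape (hidden_size/2,)
        div = torch.pow(10000, torch.arange(0, hidden_size, 2, dtype=torch.float32) / hidden_size)
        angles = pos / div      # broadcasts to (seq_max_len, hidden_size/2)
        posi_enc[:, :, 0::2] = torch.sin(angles)     # whenever a tensor op can do it, avoid a Python for loop
        posi_enc[:, :, 1::2] = torch.cos(angles)
        # a buffer moves with model.cuda()/.to() but is not a trainable parameter
        self.register_buffer('posi_enc', posi_enc, persistent=False)

    def forward(self, X):   # note: X has shape (batch, seq_len, d), one batch dimension more than the (n, d) above
        X = X + self.posi_enc[:, :X.shape[1], :]
        return self.dropout(X)  # dropout to reduce overfitting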


## Positionwise FFN
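The position-wise feed-forward network applies the same MLP to every embedding in the sequence independently, mapping each one to another dimension.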

In [None]:
class PositionwiseFFN(nn.Module):
    def __init__(self, input_h_size, ffn_h_size, ffn_output_size):
        super().__init__()
        self.dense1 = nn.Linear(input_h_size, ffn_h_size)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_h_size, ffn_output_size)

    def forward(self, X):
        return self.dense2(self.relu(self.dense1(X)))
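Why "position-wise": `nn.Linear` only acts on the last dimension, so running the MLP on the whole `(batch, seq_len, hidden)` tensor is the same as running it on each position separately. A quick check, using an `nn.Sequential` stand-in with the same Linear, ReLU, Linear structure (sizes are illustrative):

```python
import torch
from torch import nn

torch.manual_seed(0)
# stand-in for PositionwiseFFN: Linear -> ReLU -> Linear
ffn = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 16))

X = torch.randn(2, 5, 16)          # (batch, seq_len, hidden)
out_all = ffn(X)                    # one call on the whole sequence

# the same MLP applied to each position on its own
out_each = torch.stack([ffn(X[:, t]) for t in range(X.shape[1])], dim=1)
print(out_all.shape)                                   # torch.Size([2, 5, 16])
print(torch.allclose(out_all, out_each, atol=1e-5))    # True
```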


## (Masked) Multi-head Attention
1. The attention function
2. Self-attention
3. Bahdanau attention
4. Multi-head attention, with masking support

### Attention function: Scaled Dot-Product Attention
Given $n$ queries and a sequence of length $m$, we have the query matrix $\mathbf{Q} \in \mathbb{R}^{n\times d}$, the key matrix $\mathbf{K} \in \mathbb{R}^{m\times d}$ and the value matrix $\mathbf{V} \in \mathbb{R}^{m\times v}$. The attention output for all queries is computed as
$$
\mathrm{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d}}\right)\mathbf{V} \in \mathbb{R}^{n\times v}
$$
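A worked instance of the formula with small illustrative sizes ($n=2$, $m=3$, $d=4$, $v=5$), to make the shapes concrete: the softmax turns each row of scores into a distribution over the $m$ keys, and each output row is a weighted average of the value vectors.

```python
import math
import torch

torch.manual_seed(0)
n, m, d, v = 2, 3, 4, 5            # 2 queries, 3 key-value pairs
Q = torch.randn(n, d)
K = torch.randn(m, d)
V = torch.randn(m, v)

scores = Q @ K.T / math.sqrt(d)          # (n, m): every query against every key
weights = torch.softmax(scores, dim=-1)  # each row sums to 1
out = weights @ V                        # (n, v): weighted average of the values

print(out.shape)                # torch.Size([2, 5])
print(weights.sum(dim=-1))      # ones: each row is a distribution over the keys
```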

In [None]:
class DotProductAttention(nn.Module):
    def __init__(self, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
    @staticmethod
    def masked_softmax(scores, valid_lens):
        """
        scores: [batch_size, num_queries, num_keys]
        The num_keys dimension is the length of the sequence being queried;
        valid_lens gives the valid (unpadded) length along this dimension.
        """
        def _seq_mask(scores, valid_lens, value=0):
            max_seq_len = scores.size(1)    # num_keys
            mask = torch.arange(max_seq_len,
                                dtype=torch.float32, device=scores.device)[None, :] < valid_lens[:, None]
            return scores.masked_fill(~mask, value)    # out-of-place, safe for autograd

        if valid_lens is None:
            return nn.functional.softmax(scores, dim=-1)
        else:
            qk_shape = scores.shape
            if valid_lens.dim() == 1:
                # one valid length per sequence: repeat it for every query
                valid_lens = torch.repeat_interleave(valid_lens, qk_shape[1])
            else:
                # one valid length per query (e.g. the decoder's causal mask)
                valid_lens = valid_lens.reshape(-1)
            # masked positions get a large negative score so their softmax weight is ~0
            scores = _seq_mask(scores.reshape(-1, qk_shape[-1]), valid_lens, value=-1e6)
            return nn.functional.softmax(scores.reshape(qk_shape), dim=-1)

    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[2]
        alpha = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        self.attention_weights = self.masked_softmax(alpha, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)
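Why the fill value for masked positions must be a large negative number: the softmax weight of a masked key should be about 0. A tiny positive fill such as `1e-6` behaves almost like a score of 0, so padding would still receive a full share of attention. A standalone sketch of the same masking idea (the function name `masked_softmax_demo` is hypothetical), with all scores equal so the effect is easy to read:

```python
import torch

def masked_softmax_demo(scores, valid_lens, fill):
    """Mask keys at positions >= valid_len (per batch row), then softmax over keys."""
    num_keys = scores.shape[-1]
    mask = torch.arange(num_keys)[None, None, :] < valid_lens[:, None, None]  # (batch, 1, num_keys)
    return torch.softmax(scores.masked_fill(~mask, fill), dim=-1)

scores = torch.zeros(2, 1, 4)            # 2 sequences, 1 query, 4 keys
valid_lens = torch.tensor([2, 3])        # 2 and 3 real tokens, the rest is padding

good = masked_softmax_demo(scores, valid_lens, fill=-1e6)
bad = masked_softmax_demo(scores, valid_lens, fill=1e-6)
print(good)   # rows [0.5, 0.5, 0, 0] and [1/3, 1/3, 1/3, 0]
print(bad)    # every key, padding included, gets about 0.25
```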

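Let us explain multi-head self-attention right away; otherwise it is easy to lose track of what the dimensions of $\mathbf{q}$, $\mathbf{k}$ and $\mathbf{v}$ mean and why each of them is multiplied by yet another weight matrix. The Transformer uses multi-head self-attention. "Self-attention" means that $\mathbf{q}$, $\mathbf{k}$ and $\mathbf{v}$ all come from the same text sequence: they are obtained by projecting the embedding vectors of that sequence's tokens with weight matrices. "Multi-head" means there are several sets of these projection matrices. The outputs of all heads are concatenated and mapped by an output weight matrix $\mathbf{W}_o$. This gives the following equations: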
$$
\begin{align}
\mathbf{h}_{i}&=f\left(\mathbf{W}_{i}^{(q)} \mathbf{q}, \mathbf{W}_{i}^{(k)} \mathbf{k}, \mathbf{W}_{i}^{(v)} \mathbf{v}\right) \in \mathbb{R}^{p_{v}} \\
\text{output} &= \mathbf{W}_{o}\left[\begin{array}{c}
\mathbf{h}_{1} \\
\vdots \\
\mathbf{h}_{h}
\end{array}\right] \in \mathbb{R}^{p_{o}}
\end{align}
$$

In self-attention, the $\mathbf{q}$, $\mathbf{k}$ and $\mathbf{v}$ in Eq. (1) are all token embedding vectors of the same sequence.

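The implementation below follows d2l. It sets the projected dimensions to $p_q = p_k = p_v = p_o/h$ and concatenates the $\mathbf{q}$, $\mathbf{k}$ and $\mathbf{v}$ projections of all heads into a single projection each, so the $h$ heads can be computed in parallel. In the implementation, `num_hiddens` is $p_o$.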
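Why one big projection can replace $h$ small ones: stacking the per-head matrices $\mathbf{W}_i^{(q)} \in \mathbb{R}^{p\times d}$ row-wise gives one $hp \times d$ matrix, and splitting its output's last dimension into $h$ chunks of size $p$ recovers each head's projection. A sketch checking this with illustrative sizes ($d_{model}=12$, $h=3$, so $p=4$):

```python
import torch
from torch import nn

torch.manual_seed(0)
batch, seq_len, d_model, h = 2, 5, 12, 3
p = d_model // h                                    # p_q = p_o / h per head

X = torch.randn(batch, seq_len, d_model)
W_q = nn.Linear(d_model, d_model, bias=False)       # all heads' W_i^(q) stacked in one weight

# one matmul for all heads, then split the last dimension into h heads
Q_all = W_q(X).reshape(batch, seq_len, h, p).permute(0, 2, 1, 3)   # (batch, h, seq_len, p)

# head by head: head i uses rows i*p:(i+1)*p of the stacked weight (nn.Linear computes x @ W.T)
for i in range(h):
    W_i = W_q.weight[i * p:(i + 1) * p]             # (p, d_model)
    Q_i = X @ W_i.T                                  # (batch, seq_len, p)
    assert torch.allclose(Q_all[:, i], Q_i, atol=1e-5)
print("stacked projection == per-head projections")
```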

In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self, attention_fun, num_hiddens, num_heads, bias=False, **kwargs):
        super().__init__()
        self.num_heads = num_heads
        self.attention = attention_fun
        self.W_q = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_k = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_v = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_o = nn.LazyLinear(num_hiddens, bias=bias)
    
    def reshape_qkv(self, X):
        # (batch_size, seq_len, num_hiddens) -> (batch_size, seq_len, num_heads, num_hiddens/num_heads)
        X = X.reshape(X.shape[0], X.shape[1], self.num_heads, -1)

        # -> (batch_size, num_heads, number of queries or key-value pairs (sequence length), p_q / p_k / p_v)
        X = X.permute(0, 2, 1, 3)
        # merge the batch and head dimensions so every head is handled as an independent sample
        return X.reshape(-1, X.shape[2], X.shape[3])    # (batch_size * num_heads, seq_len, p)

    def reshape_attention_output(self, X):
        # inverse of reshape_qkv
        X = X.reshape(-1, self.num_heads, X.shape[1], X.shape[2])
        X = X.permute(0, 2, 1, 3)
        return X.reshape(X.shape[0], X.shape[1], -1)    # (batch_size, number of queries, p_o), heads concatenated

    def forward(self, queries, keys, values, valid_lens):
        """
        Input q/k/v shape: [batch_size, number of queries or key-value pairs (sequence length), d_q / d_k / d_v]
        After projection, each head's q/k/v dimensions are p_q, p_k, p_v.
        """
        queries = self.reshape_qkv(self.W_q(queries))   # (batch_size * num_heads, number of queries, p_q)
        keys = self.reshape_qkv(self.W_k(keys))
        values = self.reshape_qkv(self.W_v(values))

        if valid_lens is not None:
            # all heads of one sample share that sample's valid lengths
            valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)
        output = self.attention(queries, keys, values, valid_lens)  # (batch_size * num_heads, number of queries, p_o/h)
        output_concat_heads = self.reshape_attention_output(output)
        return self.W_o(output_concat_heads)


## Add & Norm
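The Transformer uses layer normalization. Self-attention works within each sequence, and LayerNorm likewise works within each token: it normalizes every token's feature vector over the embedding dimension to mean 0 and variance 1 (followed by a learnable scale and shift), so, unlike batch normalization, it does not depend on statistics across the batch or on the varying sequence lengths.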

In [None]:
class AddNorm(nn.Module):
    def __init__(self, embedding_dim, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(embedding_dim)

    def forward(self, X, Y):
        return self.layer_norm(self.dropout(Y) + X)
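A quick check of what the Add & Norm step produces, written standalone with `nn.LayerNorm` and dropout omitted (sizes are illustrative): after the residual addition, every token vector is normalized on its own over the last dimension, even when the sublayer output is on a very different scale from its input.

```python
import torch
from torch import nn

torch.manual_seed(0)
X = torch.randn(2, 5, 8)                  # sublayer input (batch, seq_len, d)
Y = torch.randn(2, 5, 8) * 3 + 1          # sublayer output, deliberately on another scale
norm = nn.LayerNorm(8)                    # freshly created: scale = 1, shift = 0

out = norm(X + Y)                         # Add & Norm without dropout
print(out.mean(dim=-1))                   # ~0 for every (sequence, token)
print(out.var(dim=-1, unbiased=False))    # ~1 for every (sequence, token)
```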

## Encoder
Now assemble the modules above into a Transformer encoder.

In [None]:
class TransformerEncoderBlock(nn.Module):
    def __init__(self, num_heads, num_hiddens, dropout, ffn_h_size, bias=False):
        super().__init__()
        attention_fun = DotProductAttention(dropout)
        self.attention = MultiHeadAttention(attention_fun, num_hiddens, num_heads, bias)
        self.positionwise_ffn = PositionwiseFFN(num_hiddens, ffn_h_size, num_hiddens)
        self.add_norm1 = AddNorm(num_hiddens, dropout)
        self.add_norm2 = AddNorm(num_hiddens, dropout)
    
    def forward(self, X, valid_len):
        Y = self.add_norm1(X, self.attention(X, X, X, valid_len))
        output = self.add_norm2(Y, self.positionwise_ffn(Y))
        return output

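The encoder multiplies the embedded input sequence by $\sqrt{\mathrm{embedding\_dim}}$. The reasoning is that the token embedding vectors are initialized with mean 0 and standard deviation $\frac{1}{\sqrt{\mathrm{embedding\_dim}}}$ (0.044 for 512, 0.03125 for 1024), while the positional encoding values lie in $[-1, 1]$. Added directly, the word-embedding signal would be relatively too small, so it is first multiplied by $\sqrt{\mathrm{embedding\_dim}}$ (22.6 for 512, 32 for 1024). Note that this argument assumes that initialization: PyTorch's `nn.Embedding` initializes its weights from $\mathcal{N}(0, 1)$ by default.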
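The arithmetic above, checked numerically. This sketch assumes, as the text does, that the embedding weights are drawn with standard deviation $1/\sqrt{d}$; it overrides PyTorch's default $\mathcal{N}(0,1)$ initialization explicitly to get there. The vocabulary size 1000 is illustrative.

```python
import math
import torch
from torch import nn

torch.manual_seed(0)
d = 512
emb = nn.Embedding(1000, d)
# assumption from the text: std = 1/sqrt(d) (PyTorch's own default would be N(0, 1))
nn.init.normal_(emb.weight, mean=0.0, std=d ** -0.5)

raw = emb(torch.arange(1000))
scaled = raw * math.sqrt(d)
print(round(raw.std().item(), 3))     # about 0.044 = 1/sqrt(512)
print(round(scaled.std().item(), 3))  # about 1.0, the same order as sin/cos values in [-1, 1]
```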

In [None]:
class TransformerEncoder(nn.Module):
    def __init__(self, num_heads, num_hiddens, dropout, ffn_h_size,
                 num_blks, vocab_size, pad_idx=None, bias=False):
        super().__init__()
        self.num_hiddens = num_hiddens
        self.num_heads = num_heads
        self.embedding = nn.Embedding(vocab_size, num_hiddens, padding_idx=pad_idx)
        self.position_ecd = PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_blks):
            self.blks.add_module('block'+str(i), 
                                 TransformerEncoderBlock(num_heads, num_hiddens, 
                                                         dropout, ffn_h_size, bias))
            
    def forward(self, X, valid_lens):
        X = self.position_ecd(self.embedding(X) * math.sqrt(self.num_hiddens))
        self.attention_weights = []
        for i, blk in enumerate(self.blks):
            X = blk(X, valid_lens)
            self.attention_weights.append(blk.attention.attention.attention_weights)
        return X


## Decoder
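The first attention layer of the Transformer decoder is masked so that, during training, each position of the target sequence can only see the tokens before it. In this first attention layer, the queries, keys and values all come from the output of the previous decoder block.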
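The decoder block implements this mask through per-query valid lengths instead of an explicit mask matrix: during training, query $t$ gets valid length $t+1$. A small standalone check (sizes are illustrative) that these valid lengths yield exactly the usual lower-triangular causal mask:

```python
import torch

seq_len, batch = 4, 2
# per-query valid lengths, built the same way as dec_valid_len in the decoder block
dec_valid_len = torch.arange(1, seq_len + 1).repeat(batch, 1)        # (batch, seq_len)

# turn them into a boolean mask over keys (True = may attend)
keys = torch.arange(seq_len)
mask = keys[None, None, :] < dec_valid_len[:, :, None]              # (batch, seq_len, seq_len)

causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))  # lower-triangular mask
print(mask[0].int())
print(torch.equal(mask[0], causal) and torch.equal(mask[1], causal))  # True
```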

In [None]:
class TransformerDecoderBlock(nn.Module):
    def __init__(self, num_heads, num_hiddens,
                 ffn_h_size, dropout, i, bias=False):
        super().__init__()
        self.i = i    # index of this block, i.e. its slot in context_time_step
        self.num_heads = num_heads
        self.num_hiddens = num_hiddens
        # separate attention functions, so each layer keeps its own attention_weights
        self.attention1 = MultiHeadAttention(DotProductAttention(dropout), num_hiddens, num_heads, bias)
        self.add_norm1 = AddNorm(num_hiddens, dropout)
        self.attention2 = MultiHeadAttention(DotProductAttention(dropout), num_hiddens, num_heads, bias)
        self.add_norm2 = AddNorm(num_hiddens, dropout)
        self.positionwise_ffn = PositionwiseFFN(num_hiddens, ffn_h_size, num_hiddens)
        self.add_norm3 = AddNorm(num_hiddens, dropout)
    
    def forward(self, X, context_time_step, enc_output, enc_valid_len):
        """
        X.shape: [batch_size, seq_len_after_padding, embedding_size]
        context_time_step[self.i]: this block's inputs from earlier time steps
        (None during training and at the first decoding step)
        """
        if context_time_step[self.i] is None:
            key_values = X
        else:
            key_values = torch.cat((context_time_step[self.i], X), dim=1)
        context_time_step[self.i] = key_values    # update this block's cache in place

        if self.training:
            # causal mask: query t may only attend to the first t+1 keys
            batch_size = X.shape[0]
            seq_len_after_padding = X.shape[1]
            dec_valid_len = torch.arange(1, seq_len_after_padding + 1,
                                         device=X.device).repeat(batch_size, 1)
        else:
            # step-by-step decoding: the keys only contain past tokens, so no mask is needed
            dec_valid_len = None
        
        X2 = self.attention1(X, key_values, key_values, dec_valid_len)    # masked self-attention
        output_of_add_norm1 = self.add_norm1(X, X2)
        output_of_attention2 = self.attention2(output_of_add_norm1,      # encoder-decoder attention
                                               enc_output, enc_output, enc_valid_len)
        output_of_add_norm2 = self.add_norm2(output_of_add_norm1,
                                             output_of_attention2)
        output_of_ffn = self.positionwise_ffn(output_of_add_norm2)
        output = self.add_norm3(output_of_add_norm2, output_of_ffn)

        return output, context_time_step
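`context_time_step` is meant as a per-block cache for step-by-step inference: at each step the new token's query attends to the cached earlier inputs plus itself. A simplified single-head sketch (no projection matrices; the function name `attend` is hypothetical) showing that incremental decoding with such a cache reproduces full-sequence attention under a causal mask:

```python
import math
import torch

def attend(q, k, v, mask=None):
    """Single-head scaled dot-product attention; mask: True = may attend."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    if mask is not None:
        scores = scores.masked_fill(~mask, -1e6)
    return torch.softmax(scores, dim=-1) @ v

torch.manual_seed(0)
X = torch.randn(1, 6, 8)                  # (batch, seq_len, d): inputs to one decoder block

# training style: the whole sequence at once, with a causal mask
causal = torch.tril(torch.ones(6, 6, dtype=torch.bool))
full = attend(X, X, X, causal)

# inference style: one token per step, keys/values = cache of earlier inputs + current input
cache, steps = None, []
for t in range(X.shape[1]):
    x_t = X[:, t:t + 1]                                   # (batch, 1, d) current token
    cache = x_t if cache is None else torch.cat((cache, x_t), dim=1)
    steps.append(attend(x_t, cache, cache))              # no mask: the cache holds only the past
incremental = torch.cat(steps, dim=1)

print(torch.allclose(full, incremental, atol=1e-5))      # True
```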

In [None]:
class TransformerDecoder(nn.Module):
    def __init__(self, num_heads, num_hiddens,
                 ffn_h_size, dropout, num_blks,
                 vocab_size, pad_idx=None, bias=False):
        super().__init__()
        self.num_hiddens = num_hiddens
        self.num_blks = num_blks
        self.embedding = nn.Embedding(vocab_size, num_hiddens, padding_idx=pad_idx)
        self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_blks):
            # the argument order must match TransformerDecoderBlock's signature
            self.blks.add_module("block"+str(i), TransformerDecoderBlock(
                num_heads, num_hiddens, ffn_h_size, dropout, i, bias))
        self.dense = nn.LazyLinear(vocab_size)

    def forward(self, X, context_time_step, enc_output, enc_valid_len):
        """
        context_time_step: list with one cache entry per block (None at the start);
        each block updates its own entry in place.
        """
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        for blk in self.blks:
            X, context_time_step = blk(X, context_time_step, enc_output, enc_valid_len)

        return self.dense(X)    # logits: (batch_size, seq_len, vocab_size)


## 构建Transformer
Assemble the modules above into a complete Transformer.

In [None]:
class MyTransformer(nn.Module):
    def __init__(self, num_heads, num_hiddens,
                 ffn_h_size, dropout, num_blks_enc, num_blks_dec,
                 vocab_size_src, vocab_size_tar,
                 pad_idx=None, bias=False):
        super().__init__()
        self.pad_idx = pad_idx
        self.encoder = TransformerEncoder(num_heads, num_hiddens, dropout, ffn_h_size,
                                          num_blks_enc, vocab_size_src, pad_idx, bias)
        # the decoder already ends with a Linear layer to vocab_size_tar,
        # so no extra output layer is needed here
        self.decoder = TransformerDecoder(num_heads, num_hiddens, ffn_h_size, dropout,
                                          num_blks_dec, vocab_size_tar, pad_idx, bias)

    def forward(self, srcX, tarX, enc_valid_len, context_time_step=None):
        enc_output = self.encoder(srcX, enc_valid_len)
        if context_time_step is None:
            # one (initially empty) cache slot per decoder block
            context_time_step = [None] * self.decoder.num_blks
        # logits: (batch_size, tar_seq_len, vocab_size_tar)
        return self.decoder(tarX, context_time_step, enc_output, enc_valid_len)


## 训练
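The training data is the small dataset from [this repository](https://github.com/P3n9W31/transformer-pytorch), introduced in [this article](https://zhuanlan.zhihu.com/p/581334630). It is only a few KB in size, with a vocabulary of roughly 10,000 tokens.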

In [4]:
# Config
BATCH_SIZE = 64
LR = 0.001
NUM_HIDDENS = 512
FFN_H_SIZE = 2048
N_LAYERS = 6
NUM_HEADS = 8
DROPOUT_RATE = 0.2
N_EPOCH = 60
PAD_ID = 0
TRAIN_SET_PROP = 0.8

from torch.utils import data
from data_load import load_data, load_cn_vocab, load_en_vocab
from data_my_Transformer import en2cn_dataset
import os, sys
import numpy as np
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'   # environment variable values must be strings

cn2idx, idx2cn = load_cn_vocab()
en2idx, idx2en = load_en_vocab()
X_train, Y_train, _, _ = load_data('train') # X: cn, Y: en | [num_sentences, seq_len_after_padding]
dataset = en2cn_dataset(X_train, Y_train)
train_size = int(len(dataset) * TRAIN_SET_PROP)
test_size = len(dataset) - train_size
train_dataset, test_dataset = data.random_split(dataset, [train_size, test_size])
train_data_loader = data.DataLoader(train_dataset, BATCH_SIZE, shuffle=True)
test_data_loader = data.DataLoader(test_dataset, BATCH_SIZE, shuffle=True)

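model = MyTransformer(NUM_HEADS, NUM_HIDDENS, FFN_H_SIZE, DROPOUT_RATE,
                      N_LAYERS, N_LAYERS, len(en2idx), len(cn2idx), PAD_ID).cuda()
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID)   # padding positions do not count toward the loss

# LazyLinear layers get their weight shapes on the first forward pass,
# so run one batch through the model before building the optimizer.
with torch.no_grad():
    en_batch, cn_batch = [t.cuda().long() for t in next(iter(train_data_loader))]
    model(en_batch, cn_batch, (en_batch != PAD_ID).sum(dim=1))
optimizer = torch.optim.Adam(model.parameters(), LR)

for epoch_i in range(N_EPOCH):
    model.train()   # the decoder uses its causal mask only in training mode
    for data_batch in train_data_loader:
        en_batch, cn_batch = [t.cuda().long() for t in data_batch]
        # valid (non-padding) length of each source sentence; assumes right padding with PAD_ID
        enc_valid_len = (en_batch != PAD_ID).sum(dim=1)
        # teacher forcing (assumption: no separate start token is prepended, so the
        # decoder sees target tokens 0..T-2 and learns to predict tokens 1..T-1)
        dec_input, dec_target = cn_batch[:, :-1], cn_batch[:, 1:]
        cn_hat = model(en_batch, dec_input, enc_valid_len)     # (batch_size, T-1, len(cn2idx))
        loss = loss_fn(cn_hat.permute(0, 2, 1), dec_target)   # CrossEntropyLoss wants classes in dim 1
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f'epoch {epoch_i}: loss {loss.item():.4f}')
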
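How the padding-aware pieces of a training step fit together, in isolation. A minimal sketch assuming right padding with `PAD_ID = 0`, teacher forcing by shifting the target one position, and `nn.CrossEntropyLoss(ignore_index=PAD_ID)`; the token ids are hypothetical and random logits stand in for the model output, so only the bookkeeping is being checked.

```python
import torch
from torch import nn

torch.manual_seed(0)
PAD_ID, vocab = 0, 20
src = torch.tensor([[4, 4, 1, PAD_ID],
                    [6, 2, 3, 5]])                  # hypothetical source ids, right-padded
tgt = torch.tensor([[5, 8, 3, 2, PAD_ID, PAD_ID],
                    [7, 4, 9, 6, 11, 2]])           # hypothetical target ids, right-padded

# encoder valid lengths: count of non-padding tokens per sentence
enc_valid_len = (src != PAD_ID).sum(dim=1)          # tensor([3, 4])

# teacher forcing: the decoder sees tokens 0..T-2 and must predict tokens 1..T-1
dec_input, dec_target = tgt[:, :-1], tgt[:, 1:]
logits = torch.randn(tgt.shape[0], dec_input.shape[1], vocab)   # stand-in for the model output

loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID)
loss = loss_fn(logits.permute(0, 2, 1), dec_target)  # classes must be in dim 1: (batch, vocab, T-1)

# the same value by hand: mean negative log-likelihood over non-padding targets only
logp = torch.log_softmax(logits, dim=-1)
nll = -logp.gather(-1, dec_target.unsqueeze(-1)).squeeze(-1)   # (batch, T-1)
keep = dec_target != PAD_ID
print(torch.allclose(loss, nll[keep].mean()))   # True
```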
