# Transformer

Transformer 是一种 encoder-decoder 结构的网络：它完全不使用 CNN 和 RNN，内部只采用注意力机制。

## 一、导入需要的包

In [1]:
import torch
import math
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data

# Compute device used throughout the notebook: the first CUDA GPU when one is
# available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

## 二、定义自己的测试数据集

我们自定义一个简短的NLP数据集：两句German到English的翻译数据，旨在利用这两句数据训练一个用transformer实现的简单翻译任务。

In [2]:
# Toy parallel corpus: two German -> English sentence pairs.
# 'P' pads the source sentence, 'S' starts the decoder input and 'E' ends the
# decoder target.
sentences = [
    # enc_input                dec_input           dec_output
    ['ich mochte ein bier P', 'S i want a beer .', 'i want a beer . E'],
    ['ich mochte ein cola P', 'S i want a coke .', 'i want a coke . E'],
]

# Word -> index vocabularies for German (source) and English (target).
# The padding token 'P' must be index 0.
src_vocab = {'P': 0, 'ich': 1, 'mochte': 2, 'ein': 3, 'bier': 4, 'cola': 5}
tgt_vocab = {'P': 0, 'i': 1, 'want': 2, 'a': 3, 'beer': 4, 'coke': 5, 'S': 6, 'E': 7, '.': 8}

# Index -> word lookups. They are built by inverting the vocabularies rather
# than enumerating them, so they stay correct even if the ids are not listed
# in 0, 1, 2, ... insertion order.
src_idx2word = {idx: word for word, idx in src_vocab.items()}
tgt_idx2word = {idx: word for word, idx in tgt_vocab.items()}

# Vocabulary sizes (used as embedding / projection dimensions).
src_vocab_len = len(src_vocab)
tgt_vocab_len = len(tgt_vocab)

# Longest encoder / decoder sequence, derived from the corpus so the values
# cannot drift out of sync with the sentences above (5 and 6 for this data).
src_len = max(len(pair[0].split()) for pair in sentences)
tgt_len = max(len(pair[1].split()) for pair in sentences)

In [3]:
class MyDataLoader(Data.Dataset):
    """Dataset that serves aligned (encoder input, decoder input, decoder target) rows.

    The three containers are indexed in lockstep; the dataset length is the
    size of the first dimension of ``enc_inputs``.
    """

    def __init__(self, enc_inputs, dec_inputs, dec_outputs):
        super().__init__()
        self.enc_inputs = enc_inputs
        self.dec_inputs = dec_inputs
        self.dec_outputs = dec_outputs

    def __len__(self):
        num_samples = self.enc_inputs.shape[0]
        return num_samples

    def __getitem__(self, index):
        columns = (self.enc_inputs, self.dec_inputs, self.dec_outputs)
        return tuple(column[index] for column in columns)
	
def word2num(sentences):
    """Convert whitespace-separated sentence triples into index tensors.

    Each entry of ``sentences`` holds, in order, the encoder input, the decoder
    input and the decoder target as strings. Source words are looked up in
    ``src_vocab`` and target words in ``tgt_vocab``.

    Returns:
        A tuple ``(enc_inputs, dec_inputs, dec_outputs)`` of ``LongTensor``s,
        one row per sentence triple.
    """
    def encode(text, vocab):
        # Map every whitespace-separated token to its vocabulary index.
        return [vocab[token] for token in text.split()]

    enc_rows = [encode(triple[0], src_vocab) for triple in sentences]
    dec_in_rows = [encode(triple[1], tgt_vocab) for triple in sentences]
    dec_out_rows = [encode(triple[2], tgt_vocab) for triple in sentences]

    return (
        torch.LongTensor(enc_rows),
        torch.LongTensor(dec_in_rows),
        torch.LongTensor(dec_out_rows),
    )

# Encode the corpus once, then wrap it in a DataLoader that yields shuffled
# batches of two samples.
enc_inputs, dec_inputs, dec_outputs = word2num(sentences)
loader = Data.DataLoader(
    MyDataLoader(enc_inputs, dec_inputs, dec_outputs),
    batch_size=2,
    shuffle=True,
)

## 三、构建Transformer网络结构

In [4]:
# Transformer hyper-parameters

# Width of the token embeddings and of the positional encodings.
d_model = 512

# Hidden width of the position-wise feed-forward network (512 -> 2048 -> 512).
d_ff = 2048

# Per-head width of the query/key vectors and of the value vectors. Q and K
# must share a width; V is given the same width for convenience.
d_k = 64
d_v = 64

# Number of stacked blocks in the encoder and in the decoder.
n_layers = 6

# Number of attention heads in each multi-head attention layer.
n_heads = 8

### Transformer-Encoder 部分代码实现

In [9]:
"""
	为什么需要 PositionEncoding:
		Transformer 输入的句子的词语都是同时输入，这样是没有位置先后关系的。
		因此，Transformer是需要额外的处理来告知每个词语的相对位置的。其中的
		一个解决方案，就是论文中提到的Positional Encoding，将能表示位置信
		息的编码添加到输入中，让网络知道每个词的位置和顺序。
"""
class PositionEncoding(nn.Module):
    """Add fixed sinusoidal position information to a sequence of embeddings.

    All tokens of a sentence are fed to the Transformer at once, so the model
    has no built-in notion of word order. This module adds a deterministic
    sine/cosine signal to every position so later layers can tell positions
    apart.

    Args:
        d_model: Width of the embeddings. Both even and odd values are supported.
        dropout: Dropout probability applied after the encoding is added.
        max_len: Number of positions pre-computed in the table; inputs must
            not be longer than this.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns, so
        # trim the angle table to fit instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])

        # Stored as [max_len, 1, d_model] so it broadcasts over the batch axis.
        # A buffer follows the module across devices but is not a parameter.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Return ``x`` plus the positional signal, followed by dropout.

        x: [seq_len, batch_size, d_model]
        """
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
	
def get_attn_pad_mask(seq_q, seq_k):
    """Build a mask that hides padding keys from every query position.

    Padding tokens have index 0. Masking them lets the attention weights on
    those keys be driven to zero, so padded positions do not contribute to the
    weighted average of the values. Both encoder and decoder use this helper,
    so the two sequence lengths may differ.

    Args:
        seq_q: [batch_size, len_q] token indices on the query side.
        seq_k: [batch_size, len_k] token indices on the key side.

    Returns:
        Boolean tensor [batch_size, len_q, len_k]; True marks a masked key.
    """
    _, len_q = seq_q.shape
    batch_size, len_k = seq_k.shape
    # [batch_size, 1, len_k]: one row per sentence flagging its padded keys.
    key_is_pad = (seq_k.data == 0)[:, None, :]
    # Repeat that row (as a view, without copying) for every query position.
    return key_is_pad.expand(batch_size, len_q, len_k)

class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V."""

    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask):
        """Attend from the queries to the keys and average the values.

        Args:
            Q: [batch_size, n_heads, len_q, d_k]
            K: [batch_size, n_heads, len_k, d_k]
            V: [batch_size, n_heads, len_v(=len_k), d_v]
            attn_mask: [batch_size, n_heads, len_q, len_k]; True (or non-zero)
                marks a key that the query must not attend to.

        In encoder-decoder attention ``len_q`` and ``len_k`` may differ.

        Returns:
            ``(context, attn)`` where ``context`` is
            [batch_size, n_heads, len_q, d_v] and ``attn`` is
            [batch_size, n_heads, len_q, len_k].
        """
        # Scale by the actual key width instead of a module-level constant so
        # the layer stays correct for any head size.
        scale = math.sqrt(K.size(-1))
        scores = torch.matmul(Q, K.transpose(-1, -2)) / scale

        # A large negative score becomes (numerically) zero weight after the
        # softmax. ``.bool()`` also accepts uint8 masks.
        scores = scores.masked_fill(attn_mask.bool(), -1e9)

        attn = torch.softmax(scores, dim=-1)
        context = torch.matmul(attn, V)
        return context, attn

class MultiHeadAttention(nn.Module):
    """Multi-head attention followed by a residual connection and LayerNorm.

    The projections of all heads are stored in single weight matrices; the
    projected tensors are then split into ``n_heads`` heads.
    """

    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=False)
        self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)
        self.attention = ScaledDotProductAttention()
        # Registered here (not created inside forward) so its affine
        # parameters are trained, saved with the model and moved together with
        # it by ``.to(...)``.
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, input_Q, input_K, input_V, attn_mask):
        """Run multi-head attention.

        Args:
            input_Q: [batch_size, len_q, d_model]
            input_K: [batch_size, len_k, d_model]
            input_V: [batch_size, len_v(=len_k), d_model]
            attn_mask: [batch_size, len_q, len_k]; True marks a masked key.

        Returns:
            ``(output, attn)`` where ``output`` is [batch_size, len_q, d_model]
            and ``attn`` is [batch_size, n_heads, len_q, len_k].
        """
        residual, batch_size = input_Q, input_Q.size(0)

        # (B, S, D) -proj-> (B, S, H * W) -split-> (B, S, H, W) -transpose-> (B, H, S, W)
        # Q: [batch_size, n_heads, len_q, d_k]
        Q = self.W_Q(input_Q).view(batch_size, -1, n_heads, d_k).transpose(1, 2)
        # K: [batch_size, n_heads, len_k, d_k]
        K = self.W_K(input_K).view(batch_size, -1, n_heads, d_k).transpose(1, 2)
        # V: [batch_size, n_heads, len_v(=len_k), d_v]
        V = self.W_V(input_V).view(batch_size, -1, n_heads, d_v).transpose(1, 2)

        # Every head shares the same mask:
        # [batch_size, len_q, len_k] -> [batch_size, n_heads, len_q, len_k].
        # The mask is only read, so a broadcast view avoids copying it per head.
        attn_mask = attn_mask.unsqueeze(1).expand(-1, n_heads, -1, -1)

        # context: [batch_size, n_heads, len_q, d_v]
        # attn:    [batch_size, n_heads, len_q, len_k]
        context, attn = self.attention(Q, K, V, attn_mask)

        # Concatenate the heads: -> [batch_size, len_q, n_heads * d_v]
        context = context.transpose(1, 2).reshape(batch_size, -1, n_heads * d_v)
        # Final projection back to the model width: [batch_size, len_q, d_model]
        output = self.fc(context)
        return self.layer_norm(output + residual), attn

class PoswiseFeedForwardNet(nn.Module):
    """Position-wise feed-forward network with a residual connection and LayerNorm.

    Two bias-free linear layers (d_model -> d_ff -> d_model) with a ReLU in
    between are applied to the last dimension of the input.
    """

    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=False),
            nn.ReLU(),
            nn.Linear(d_ff, d_model, bias=False),
        )
        # Registered here (not created inside forward) so its affine
        # parameters are trained, saved with the model and moved together with
        # it by ``.to(...)``.
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, inputs):
        """Transform ``inputs`` and normalise the sum with the residual.

        inputs: [batch_size, seq_len, d_model]
        returns: [batch_size, seq_len, d_model]
        """
        residual = inputs
        output = self.fc(inputs)
        return self.layer_norm(output + residual)

class EncoderLayer(nn.Module):
    """One encoder block: self-attention followed by a position-wise feed-forward net."""

    def __init__(self):
        super().__init__()
        self.enc_self_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, enc_inputs, enc_self_attn_mask):
        """Apply the block to a batch of encoder states.

        Args:
            enc_inputs: [batch_size, src_len, d_model]
            enc_self_attn_mask: [batch_size, src_len, src_len] attention mask.

        Returns:
            ``(enc_outputs, attn)`` where ``enc_outputs`` is
            [batch_size, src_len, d_model] and ``attn`` is
            [batch_size, n_heads, src_len, src_len].
        """
        # Self-attention: the same tensor is the source of Q, K and V
        # (each is projected with its own weight matrix inside the layer).
        attended, attn = self.enc_self_attn(
            enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask
        )
        enc_outputs = self.pos_ffn(attended)
        return enc_outputs, attn

class Encoder(nn.Module):
    """Transformer encoder: token embedding, positional encoding and ``n_layers`` blocks."""

    def __init__(self):
        super(Encoder, self).__init__()
        self.src_embedding = nn.Embedding(src_vocab_len, d_model)
        self.pos_embedding = PositionEncoding(d_model)  # positional encoding
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])

    def forward(self, enc_inputs):
        """Encode a batch of source sentences.

        Args:
            enc_inputs: [batch_size, src_len] source token indices.

        Returns:
            ``(enc_outputs, enc_self_attns)`` where ``enc_outputs`` is
            [batch_size, src_len, d_model] and ``enc_self_attns`` is a list
            with one self-attention tensor per layer.
        """
        enc_outputs = self.src_embedding(enc_inputs)  # [batch_size, src_len, d_model]
        # The positional encoding expects [seq_len, batch_size, d_model], so
        # swap the first two axes going in and coming out.
        enc_outputs = self.pos_embedding(enc_outputs.transpose(0, 1)).transpose(0, 1)  # [batch_size, src_len, d_model]
        # Padding mask for the encoder input sequence.
        enc_self_attn_mask = get_attn_pad_mask(enc_inputs, enc_inputs)

        enc_self_attns = []  # attention weights of every layer
        for layer in self.layers:
            # The output of one block is the input of the next one.
            enc_outputs, enc_self_attn = layer(enc_outputs, enc_self_attn_mask)
            # Store this layer's weights (not the list itself).
            enc_self_attns.append(enc_self_attn)
        return enc_outputs, enc_self_attns

### Transformer-Decoder 部分代码实现

In [11]:
def get_attn_subsequence_mask(seq):
    """Build the "no peeking ahead" mask for decoder self-attention.

    Args:
        seq: [batch_size, tgt_len] tensor; only its first two sizes are used.

    Returns:
        uint8 tensor [batch_size, tgt_len, tgt_len] that is 1 strictly above
        the diagonal (positions after the current one) and 0 elsewhere.
        Printing it for a short sequence makes the pattern obvious.
    """
    batch_size, tgt_len = seq.size(0), seq.size(1)
    all_ones = torch.ones(batch_size, tgt_len, tgt_len, dtype=torch.uint8)
    # Keep only the strict upper triangle of every [tgt_len, tgt_len] slice.
    return torch.triu(all_ones, diagonal=1)

class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, encoder-decoder attention, feed-forward net."""

    def __init__(self):
        super().__init__()
        self.dec_self_attn = MultiHeadAttention()
        self.dec_enc_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):
        """Apply the block.

        Returns:
            ``(dec_outputs, dec_self_attn, dec_enc_attn)`` with shapes
            [batch_size, tgt_len, d_model],
            [batch_size, n_heads, tgt_len, tgt_len] and
            [batch_size, n_heads, tgt_len, src_len].
        """
        # Step 1: the target sequence attends to itself.
        self_attended, dec_self_attn = self.dec_self_attn(
            dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask
        )
        # Step 2: queries come from the decoder, keys and values from the encoder output.
        cross_attended, dec_enc_attn = self.dec_enc_attn(
            self_attended, enc_outputs, enc_outputs, dec_enc_attn_mask
        )
        # Step 3: position-wise feed-forward network.
        dec_outputs = self.pos_ffn(cross_attended)  # [batch_size, tgt_len, d_model]
        return dec_outputs, dec_self_attn, dec_enc_attn


class Decoder(nn.Module):
    """Transformer decoder: token embedding, positional encoding and ``n_layers`` blocks."""

    def __init__(self):
        super().__init__()
        self.tgt_embedding = nn.Embedding(tgt_vocab_len, d_model)
        self.pos_embedding = PositionEncoding(d_model)
        self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])

    def forward(self, dec_inputs, enc_inputs, enc_outputs):
        """Decode a batch of target prefixes against the encoder output.

        Args:
            dec_inputs: [batch_size, tgt_len] target token indices.
            enc_inputs: [batch_size, src_len] source token indices (used only
                to locate source padding).
            enc_outputs: [batch_size, src_len, d_model] encoder output.

        Returns:
            ``(dec_outputs, dec_self_attns, dec_enc_attns)``: the decoder
            states [batch_size, tgt_len, d_model] and, per layer, the
            self-attention and encoder-decoder attention weights.
        """
        embedded = self.tgt_embedding(dec_inputs)  # [batch_size, tgt_len, d_model]
        # The positional encoding works on [seq_len, batch_size, d_model].
        dec_outputs = self.pos_embedding(embedded.transpose(0, 1)).transpose(0, 1).to(device)

        # Two masks for self-attention, both [batch_size, tgt_len, tgt_len]:
        # one hides padding tokens, the other hides future positions.
        pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs).to(device)
        future_mask = get_attn_subsequence_mask(dec_inputs).to(device)

        # Sum them and test "> 0": a position is masked when either mask flags it.
        combined = pad_mask + future_mask
        dec_self_attn_mask = torch.gt(combined, 0).to(device)

        # Encoder-decoder attention only needs to hide source padding.
        dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs)

        dec_self_attns = []
        dec_enc_attns = []
        for block in self.layers:
            # Each block consumes the previous block's output (changing) and
            # the encoder output (fixed).
            dec_outputs, self_attn, cross_attn = block(
                dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask
            )
            dec_self_attns.append(self_attn)
            dec_enc_attns.append(cross_attn)

        # dec_outputs: [batch_size, tgt_len, d_model]
        return dec_outputs, dec_self_attns, dec_enc_attns


### Transformer网络整体架构

In [7]:
class Transformer(nn.Module):
    """Full model: encoder, decoder and a bias-free projection onto the target vocabulary."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder().to(device)
        self.decoder = Decoder().to(device)
        self.projection = nn.Linear(d_model, tgt_vocab_len, bias=False).to(device)

    def forward(self, enc_inputs, dec_inputs):
        """Compute target-vocabulary logits for every decoder position.

        Args:
            enc_inputs: [batch_size, src_len]
            dec_inputs: [batch_size, tgt_len]

        Returns:
            ``(logits, enc_self_attns, dec_self_attns, dec_enc_attns)``;
            ``logits`` is flattened to [batch_size * tgt_len, tgt_vocab_len].
        """
        enc_outputs, enc_self_attns = self.encoder(enc_inputs)
        dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(
            dec_inputs, enc_inputs, enc_outputs
        )
        logits = self.projection(dec_outputs)
        # Merge the batch and time axes so each row is one prediction.
        vocab_size = logits.size(-1)
        flat_logits = logits.view(-1, vocab_size)
        return flat_logits, enc_self_attns, dec_self_attns, dec_enc_attns

## 四、训练

In [14]:
epochs = 100
model = Transformer().to(device)
# ignore_index=0: the padding token has index 0 and carries no meaning, so
# positions whose target is padding are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)
# Author's note: Adam gave worse results than SGD with momentum here.
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.99)

# ====================================================================================================
model.train()  # make sure dropout is active while training
for epoch in range(epochs):
    # Batch tensors get their own names so the loop does not overwrite the
    # module-level dataset tensors enc_inputs / dec_inputs / dec_outputs.
    for batch_enc_inputs, batch_dec_inputs, batch_dec_outputs in loader:
        # batch_enc_inputs:  [batch_size, src_len]
        # batch_dec_inputs:  [batch_size, tgt_len]
        # batch_dec_outputs: [batch_size, tgt_len]
        batch_enc_inputs = batch_enc_inputs.to(device)
        batch_dec_inputs = batch_dec_inputs.to(device)
        batch_dec_outputs = batch_dec_outputs.to(device)

        # outputs: [batch_size * tgt_len, tgt_vocab_len]
        outputs, enc_self_attns, dec_self_attns, dec_enc_attns = model(batch_enc_inputs, batch_dec_inputs)
        # The targets are flattened to [batch_size * tgt_len] to line up with the rows of outputs.
        loss = criterion(outputs, batch_dec_outputs.view(-1))
        print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss.item()))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Epoch: 0001 loss = 2.574859
Epoch: 0002 loss = 2.292260
Epoch: 0003 loss = 2.148703
Epoch: 0004 loss = 1.836867
Epoch: 0005 loss = 1.578821
Epoch: 0006 loss = 1.387943
Epoch: 0007 loss = 1.219568
Epoch: 0008 loss = 1.083120
Epoch: 0009 loss = 0.832052
Epoch: 0010 loss = 0.695549
Epoch: 0011 loss = 0.525437
Epoch: 0012 loss = 0.406661
Epoch: 0013 loss = 0.316452
Epoch: 0014 loss = 0.250296
Epoch: 0015 loss = 0.225005
Epoch: 0016 loss = 0.203066
Epoch: 0017 loss = 0.169627
Epoch: 0018 loss = 0.146839
Epoch: 0019 loss = 0.097101
Epoch: 0020 loss = 0.089779
Epoch: 0021 loss = 0.068830
Epoch: 0022 loss = 0.081064
Epoch: 0023 loss = 0.049723
Epoch: 0024 loss = 0.046363
Epoch: 0025 loss = 0.030551
Epoch: 0026 loss = 0.041131
Epoch: 0027 loss = 0.033184
Epoch: 0028 loss = 0.028761
Epoch: 0029 loss = 0.031112
Epoch: 0030 loss = 0.021891
Epoch: 0031 loss = 0.022155
Epoch: 0032 loss = 0.011714
Epoch: 0033 loss = 0.012132
Epoch: 0034 loss = 0.011692
Epoch: 0035 loss = 0.012515
Epoch: 0036 loss = 0

## 五、预测

In [16]:
def greedy_decoder(model, enc_input, start_symbol, max_len=5000):
    """Greedily decode one sentence (beam search with K=1).

    At inference time the target sequence is unknown, so it is generated word
    by word: the tokens produced so far are fed back into the decoder and the
    most likely next token is appended, until the end token 'E' is predicted.
    Starting reference: http://nlp.seas.harvard.edu/2018/04/03/attention.html#greedy-decoding

    Decoding runs in evaluation mode (dropout off) and without gradient
    tracking, so the result is deterministic; the model's previous
    training/evaluation mode is restored before returning.

    :param model: Transformer model exposing ``encoder``, ``decoder`` and ``projection``
    :param enc_input: The encoder input, shape [1, src_len]
    :param start_symbol: Index of the start token ('S' in ``tgt_vocab``)
    :param max_len: Upper bound on the number of generated tokens, so decoding
        stops even if the model never predicts 'E'. The default matches the
        default size of the positional-encoding table.
    :return: The generated token indices, shape [1, n_generated], without the
        start token and without the end token
    """
    end_symbol = tgt_vocab["E"]
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            enc_outputs, enc_self_attns = model.encoder(enc_input)
            # The decoder input starts as just the start token and grows by
            # one predicted token per step.
            dec_input = torch.tensor([[start_symbol]], dtype=enc_input.dtype, device=enc_input.device)
            for _ in range(max_len):
                dec_outputs, _, _ = model.decoder(dec_input, enc_input, enc_outputs)
                projected = model.projection(dec_outputs)
                # Only the output at the last position predicts a new token;
                # the earlier positions repeat predictions already made.
                next_symbol = projected[0, -1].argmax(dim=-1).item()
                if next_symbol == end_symbol:
                    break
                next_token = torch.tensor([[next_symbol]], dtype=enc_input.dtype, device=enc_input.device)
                dec_input = torch.cat([dec_input, next_token], -1)
    finally:
        model.train(was_training)

    # Drop the leading start token.
    greedy_dec_predict = dec_input[:, 1:]
    return greedy_dec_predict


# ==========================================================================================
# Inference demo: take one batch of source sentences from the loader, decode
# each sentence greedily, and print both the index tensors and the words.
enc_inputs, _, _ = next(iter(loader))
for sample in enc_inputs:
    greedy_dec_predict = greedy_decoder(model, sample.view(1, -1).to(device), start_symbol=tgt_vocab["S"])
    predicted_ids = greedy_dec_predict.squeeze()
    print(sample, '->', predicted_ids)
    source_words = [src_idx2word[t.item()] for t in sample]
    target_words = [tgt_idx2word[n.item()] for n in predicted_ids]
    print(source_words, '->', target_words)

tensor([1, 2, 3, 5, 0]) -> tensor([1, 2, 3, 5, 8])
['ich', 'mochte', 'ein', 'cola', 'P'] -> ['i', 'want', 'a', 'coke', '.']
tensor([1, 2, 3, 4, 0]) -> tensor([1, 2, 3, 4, 8])
['ich', 'mochte', 'ein', 'bier', 'P'] -> ['i', 'want', 'a', 'beer', '.']
