In [1]:
import torch
from torch import nn

import numpy as np 


import re
import math
import collections

# 加载数据

In [2]:
def get_dict(filtered_pairs):
    "用已知的‘英文-中文’生成字典"
    en_vocab={}
    cn_vocab={}
    en_vocab['<pad>'],en_vocab['<bos>'],en_vocab['<eos>']=0,1,2
    cn_vocab['<pad>'],cn_vocab['<bos>'],cn_vocab['<eos>']=0,1,2
    en_idx, cn_idx = 3, 3
    for en, cn in filtered_pairs:
        for w in en:
            if w not in en_vocab:
                en_vocab[w]=en_idx
                en_idx+=1
        for w in cn:
            if w not in cn_vocab:
                cn_vocab[w]=cn_idx
                cn_idx+=1
    print(f'The size of english vocab is {len(en_vocab)}.')
    print(f'The size of chinese vocab is {len(cn_vocab)}.')
    return en_vocab, cn_vocab

In [3]:
def load_data_translation(file_name='cmn2.txt', batch_size=2, steps = 15, is_train = True, file_range=[0,1]):
    # 打开数据
    print("Loading data...")
    lines = open(file_name,encoding='utf-8').read().strip().split('\n')
    begin = int(len(lines)*file_range[0])
    end   = int(len(lines)*file_range[1])
    
    # 从txt文件中读取出数据，并成对地保存在pairs中
    words_re=re.compile(r'\w+|\,|\.|\?|^[0-9]*$')
    pairs, filtered_pairs = [], []
    for l in lines:
        en_sentence, cn_sentence = l.split('\t',1)
        pairs.append((words_re.findall(en_sentence.lower()), list(cn_sentence)))

    # 筛选出来常用的和简短的词语
    MAX_LEN = steps
    for x in pairs:
        if len(x[0])<MAX_LEN and len(x[1])<MAX_LEN:
            filtered_pairs.append(x) 
    print(f'Total sentence number: {len(filtered_pairs)}') 
    print(f'Total selected sentence number: {end - begin}') 

    # 建立字典
    en_vocab, cn_vocab = get_dict(filtered_pairs)
    vocab = [en_vocab, cn_vocab]

    # 将句子中的单词/字编码
    padded_en_sents, padded_cn_sents, padded_cn_label_sents=[], [], []
    for en,cn in filtered_pairs[begin:end]:
        padded_en_sent=en+["<eos>"]+["<pad>"]*(MAX_LEN-len(en))
        padded_cn_sent=["<bos>"]+cn+["<eos>"]+["<pad>"]*(MAX_LEN-len(cn))
        padded_cn_label_sent=cn+['<eos>']+['<pad>']*(MAX_LEN-len(cn)+1)

        padded_en_sents.append([en_vocab[w] for w in padded_en_sent])
        padded_cn_sents.append([cn_vocab[w] for w in padded_cn_sent])
        padded_cn_label_sents.append([cn_vocab[w] for w in padded_cn_label_sent])

    train_en_sents= torch.tensor(np.array(padded_en_sents))
    train_cn_sents= torch.tensor(np.array(padded_cn_sents))
    train_cn_label_sents=torch.tensor(np.array(padded_cn_label_sents))
    en_valid_len = (train_en_sents != en_vocab['<pad>']).type(torch.int32).sum(1)
    cn_valid_len = (train_cn_sents != cn_vocab['<pad>']).type(torch.int32).sum(1)

    # 转化成dataloader的形式
    data_arrays = [train_en_sents, train_cn_sents, train_cn_label_sents, en_valid_len, cn_valid_len]
    dataset = torch.utils.data.TensorDataset(*data_arrays)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size, shuffle=is_train)
    print('Successfully finished!')

    return dataloader, vocab

# 网络架构

In [4]:
def try_gpu():
    """如果存在则返回gpu, 否则返回cpu"""
    if torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')


<div align = 'center'>
    <img src = 'https://5b0988e595225.cdn.sohucs.com/images/20191001/0e70fe2d169444339bc986f9d2077cbb.jpeg'>
</div>

In [5]:
class Encoder(nn.Module):
    """编码器-解码器架构的基本编码器接口"""
    def __init__(self, **kwargs):
        super(Encoder, self).__init__(**kwargs)

    def forward(self, X, *args):
        raise NotImplementedError


class Decoder(nn.Module):
    """编码器-解码器架构的基本解码器接口"""
    def __init__(self, **kwargs):
        super(Decoder, self).__init__(**kwargs)

    def init_state(self, enc_outputs, *args):
        raise NotImplementedError

    def forward(self, X, state):
        raise NotImplementedError


class EncoderDecoder(nn.Module):
    """编码器-解码器架构的基类"""
    def __init__(self, encoder, decoder, **kwargs):
        super(EncoderDecoder, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, encoder_X, decoder_X, *args):
        enc_outputs = self.encoder(encoder_X, *args)
        dec_state = self.decoder.init_state(enc_outputs, *args)
        return self.decoder(decoder_X, dec_state)

## 编码器

### 1.1 多头注意力机制

In [6]:
def sequence_mask(X, valid_len, value=0):
    """在序列中屏蔽不相关的项"""
    maxlen = X.size(1)
    mask = torch.arange((maxlen), dtype=torch.float32, device=X.device)[None, :] < valid_len[:, None]
    X[~mask] = value
    return X


def masked_softmax(X, valid_lens):  # X:3D张量，valid_lens:1D或2D张量
    """通过在最后一个轴上掩蔽元素来执行softmax操作"""
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
            # torch.repeat_interleave(input, repeats, dim=None) → Tensor
        else:
            valid_lens = valid_lens.reshape(-1)
        X = sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6) # 最后一轴上被掩蔽的元素使用非常大的负值替换，从而其softmax输出为0
        return nn.functional.softmax(X.reshape(shape), dim=-1)

In [7]:
class DotProductAttention(nn.Module):
    """缩放点积注意力"""
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    # queries的形状：(batch_size，查询的个数，d)，              keys的形状：(batch_size，“键－值”对的个数，d)
    # values的形状： (batch_size，“键－值”对的个数，值的维度)，  valid_lens的形状:(batch_size，)或者(batch_size，查询的个数)
    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        scores = torch.bmm(queries, keys.transpose(1,2)) / math.sqrt(d)  # 设置transpose_b=True为了交换keys的最后两个维度
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)

In [8]:
def transpose_qkv(X, num_heads):
    """为了多注意力头的并行计算而变换形状"""
    # 输入X的形状:(batch_size，查询或者“键－值”对的个数，num_hiddens)
    # 输出X的形状:(batch_size，查询或者“键－值”对的个数，num_heads，num_hiddens/num_heads)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
    # 输出X的形状:(batch_size，num_heads，查询或者“键－值”对的个数, num_hiddens/num_heads)
    X = X.permute(0, 2, 1, 3)
    # 最终输出的形状:(batch_size*num_heads,查询或者“键－值”对的个数, num_hiddens/num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])


def transpose_output(X, num_heads):
    """逆转transpose_qkv函数的操作"""
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)

In [9]:
class MultiHeadAttention(nn.Module):
    """多头注意力"""
    def __init__(self, key_size, query_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # queries，keys，values的形状: (batch_size，查询或者“键－值”对的个数，num_hiddens)
        # 经过变换后，输出的queries，keys，values　的形状: (batch_size*num_heads，查询或者“键－值”对的个数，num_hiddens/num_heads)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)
        if valid_lens is not None:  # valid_lens　的形状: (batch_size，)或(batch_size，查询的个数)
            valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)  # 在轴0，将第一项（标量或者矢量）复制num_heads次，然后如此复制第二项，然后诸如此类。
        output = self.attention(queries, keys, values, valid_lens)  # output的形状:(batch_size*num_heads，查询的个数，num_hiddens/num_heads)
        output_concat = transpose_output(output, self.num_heads)  # output_concat的形状:(batch_size，查询的个数，num_hiddens)
        return self.W_o(output_concat)

### 1.2 基于位置的

In [10]:
class PositionWiseFFN(nn.Module):
    """基于位置的前馈网络"""
    def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs, **kwargs):
        super(PositionWiseFFN, self).__init__(**kwargs)
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        # self.relu = nn.ReLU()
        self.relu = nn.LeakyReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

    def forward(self, X):
        return self.dense2(self.relu(self.dense1(X)))

### 1.3 残差连接

In [11]:
class AddNorm(nn.Module):
    """残差连接后进行层规范化"""
    def __init__(self, normalized_shape, dropout, **kwargs):
        super(AddNorm, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        return self.ln(self.dropout(Y) + X)

### 1.4 位置编码

In [12]:
class PositionalEncoding(nn.Module):
    """位置编码"""
    def __init__(self, num_hiddens, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        a = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1)
        b = torch.pow(10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        self.P = torch.zeros((1, max_len, num_hiddens))
        self.P[:, :, 0::2] = torch.sin(a / b)
        self.P[:, :, 1::2] = torch.cos(a / b)

    def forward(self, X):
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)

### 1.5 编码器结构

In [13]:
class EncoderBlock(nn.Module):
    """Transformer编码器块"""
    def __init__(self, key_size, query_size, value_size, num_hiddens, 
                 norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, dropout, use_bias=False, **kwargs):
        super(EncoderBlock, self).__init__(**kwargs)
        self.attention  = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout, use_bias)
        self.addnorm1   = AddNorm(norm_shape, dropout)
        self.ffn        = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm2   = AddNorm(norm_shape, dropout)

    def forward(self, X, valid_lens):  # 自注意力+残差连接
        Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
        return self.addnorm2(Y, self.ffn(Y))  # 位置前馈+残差连接

In [14]:
class TransformerEncoder(Encoder):
    """Transformer编码器"""
    def __init__(self, vocab_size, key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, 
                 num_heads, num_layers, dropout, use_bias=False, **kwargs):
        super(TransformerEncoder, self).__init__(**kwargs)
        self.num_hiddens    = num_hiddens
        self.embedding      = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding   = PositionalEncoding(num_hiddens, dropout)
        self.blks           = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module("block" + str(i), EncoderBlock(
                key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, dropout, use_bias))

    def forward(self, X, valid_lens, *args):
        embedding = self.embedding(X) * math.sqrt(self.num_hiddens)# 因为位置编码值在-1和1之间，因此嵌入值乘以嵌入维度的平方根进行缩放，然后再与位置编码相加。
        X = self.pos_encoding(embedding) 
        self.attention_weights = [None] * len(self.blks)
        for i, blk in enumerate(self.blks):
            X = blk(X, valid_lens)
            self.attention_weights[i] = blk.attention.attention.attention_weights
        return X

## 解码器

In [15]:
class DecoderBlock(nn.Module):
    """解码器中第i个块"""
    def __init__(self, key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, 
                 num_heads,dropout, i, **kwargs):
        super(DecoderBlock, self).__init__(**kwargs)
        self.i = i
        self.attention1 = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm1   = AddNorm(norm_shape, dropout)
        self.attention2 = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm2   = AddNorm(norm_shape, dropout)
        self.ffn        = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm3   = AddNorm(norm_shape, dropout)

    def forward(self, X, state):
        enc_outputs, enc_valid_lens = state[0], state[1]
        if state[2][self.i] is None:  # 训练阶段，输出序列的所有词元都在同一时间处理， 因此state[2][self.i]初始化为None。
            key_values = X
        else:  # 预测阶段，输出序列是通过词元一个接着一个解码的，因此state[2][self.i]包含着直到当前时间步第i个块解码的输出表示。
            key_values = torch.cat((state[2][self.i], X), axis=1)
        state[2][self.i] = key_values
        if self.training:
            batch_size, num_steps, _ = X.shape  # dec_valid_lens的开头:(batch_size,num_steps), 其中每一行是[1,2,...,num_steps]
            dec_valid_lens = torch.arange(1, num_steps + 1, device=X.device).repeat(batch_size, 1)
        else:
            dec_valid_lens = None

        # 自注意力
        X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y  = self.addnorm1(X, X2)
        
        # 编码器－解码器注意力
        Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)  # enc_outputs的开头:(batch_size,num_steps,num_hiddens)
        Z  = self.addnorm2(Y, Y2)
        return self.addnorm3(Z, self.ffn(Z)), state

In [16]:
class AttentionDecoder(Decoder):
    """带有注意力机制解码器的基本接口"""
    def __init__(self, **kwargs):
        super(AttentionDecoder, self).__init__(**kwargs)

    def attention_weights(self):
        raise NotImplementedError

In [17]:
class TransformerDecoder(AttentionDecoder):
    def __init__(self, vocab_size, key_size, query_size, value_size, num_hiddens, norm_shape, 
                 ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, **kwargs):
        super(TransformerDecoder, self).__init__(**kwargs)
        self.num_hiddens    = num_hiddens
        self.num_layers     = num_layers
        self.embedding      = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding   = PositionalEncoding(num_hiddens, dropout)
        self.blks           = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module("block"+str(i), DecoderBlock(
                key_size, query_size, value_size, num_hiddens,norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, dropout, i))
        self.dense          = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        return [enc_outputs, enc_valid_lens, [None] * self.num_layers]

    def forward(self, X, state):
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        self._attention_weights = [[None] * len(self.blks) for _ in range (2)]
        for i, blk in enumerate(self.blks):
            X, state = blk(X, state)
            self._attention_weights[0][i] = blk.attention1.attention.attention_weights  # decoder self-attentio 权重
            self._attention_weights[1][i] = blk.attention2.attention.attention_weights  # encoder-decoder attention 权重
        return self.dense(X), state

    def attention_weights(self):
        return self._attention_weights

# 训练

### train函数

In [18]:
def sequence_mask(X, valid_len, value=0):
    """在序列中屏蔽不相关的项"""
    maxlen = X.size(1)
    mask = torch.arange((maxlen), dtype=torch.float32, device=X.device)[None, :] < valid_len[:, None]
    X[~mask] = value
    return X


class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):
    """带遮蔽的softmax交叉熵损失函数"""
    # pred的形状：(batch_size,num_steps,vocab_size)
    # label的形状：(batch_size,num_steps)
    # valid_len的形状：(batch_size,)
    def forward(self, pred, label, valid_len):
        weights = sequence_mask(torch.ones_like(label), valid_len)
        self.reduction='none'
        unweighted_loss = super().forward(pred.permute(0, 2, 1), label)
        weighted_loss = (unweighted_loss * weights).mean(dim=1)
        return weighted_loss


def xavier_init_weights(m):
    if type(m) == nn.Linear:
        nn.init.xavier_uniform_(m.weight)


def grad_clipping(net, theta=1):
    """裁剪梯度"""
    params = [p for p in net.parameters() if p.requires_grad]
    norm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))
    if norm > theta:
        for param in params:
            param.grad[:] *= theta / norm

In [19]:
def train_transformer(net, data_iter, lr, num_epochs, vocab, device=try_gpu()):
    """训练transformer模型"""
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=1e-5)
    loss = MaskedSoftmaxCELoss()
    net.train()
    for epoch in range(num_epochs):
        total_loss, total_num = 0.0, 0.0  # 训练损失总和，词元数量
        for batch in data_iter:
            X, Z, Y, X_valid_len, Y_valid_len = [x.to(device) for x in batch]
            decoder_input = Z
            encoder_input = X
            optimizer.zero_grad()
            Y_hat, _ = net(encoder_input, decoder_input, X_valid_len)
            l = loss(Y_hat, Y.type(torch.LongTensor).to(device), Y_valid_len)
            l.sum().backward()
            grad_clipping(net, 1)
            optimizer.step()
            with torch.no_grad():
                total_loss += l.sum()
                total_num  += Y_valid_len.sum()
        if epoch==0 or (epoch + 1) % 5 == 0:
            print(f'Epoch: {epoch + 1}, loss: {total_loss / total_num:.3f}')
    print(f'final loss {total_loss / total_num:.3f}')

## 训练

### 加载数据集

In [20]:
batch_size = 256
num_steps = 15
train_iter, vocab = load_data_translation('cmn2.txt', batch_size, num_steps)
source_vocab, target_vocab = vocab[0], vocab[1]

Loading data...
Total sentence number: 26659
Total selected sentence number: 31166
The size of english vocab is 6418.
The size of chinese vocab is 2712.
Successfully finished!


In [21]:
batch_size = 256
num_steps = 50
train_iter, vocab = load_data_translation('cmn3.txt', batch_size, num_steps)
source_vocab, target_vocab = vocab[0], vocab[1]

Loading data...


FileNotFoundError: [Errno 2] No such file or directory: 'cmn3.txt'

In [None]:
batch_size = 256
num_steps = 64
a = 0.6
train_iter, vocab = load_data_translation('cmn3.txt', batch_size, num_steps, file_range = [0,a])
source_vocab, target_vocab = vocab[0], vocab[1]

In [None]:
batch_size = 256
num_steps = 80
train_iter, vocab = load_data_translation('cmn4.txt', batch_size, num_steps)
source_vocab, target_vocab = vocab[0], vocab[1]

### 模型

结构1

In [21]:
key_size, query_size, value_size = 32, 32, 32  # k/q/v的大小
num_hiddens, num_layers, dropout= 32, 2, 0.1
ffn_num_input, ffn_num_hiddens, num_heads = 32, 64, 4
norm_shape = [32]

encoder = TransformerEncoder(len(source_vocab), key_size, query_size, value_size, num_hiddens, 
                             norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout)
decoder = TransformerDecoder(len(target_vocab), key_size, query_size, value_size, num_hiddens, 
                             norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout)
net = EncoderDecoder(encoder, decoder)
net.apply(xavier_init_weights);

结构2

In [None]:
key_size, query_size, value_size = 128, 128, 128  # k/q/v的大小
num_hiddens, num_layers, dropout= 128, 2, 0.1
ffn_num_input, ffn_num_hiddens, num_heads = 128, 256, 16
norm_shape = [128]

encoder = TransformerEncoder(len(source_vocab), key_size, query_size, value_size, num_hiddens, 
                             norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout)
decoder = TransformerDecoder(len(target_vocab), key_size, query_size, value_size, num_hiddens, 
                             norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout)
net = EncoderDecoder(encoder, decoder)
net.apply(xavier_init_weights);

结构3

In [None]:
key_size, query_size, value_size = 256, 256, 256  # k/q/v的大小
num_hiddens, num_layers, dropout= 256, 2, 0.1
ffn_num_input, ffn_num_hiddens, num_heads = 256, 512, 16
norm_shape = [256]

encoder = TransformerEncoder(len(source_vocab), key_size, query_size, value_size, num_hiddens,
                             norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout)
decoder = TransformerDecoder(len(target_vocab), key_size, query_size, value_size, num_hiddens,
                             norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout)
net = EncoderDecoder(encoder, decoder)
net.apply(xavier_init_weights);

结构4

In [None]:
key_size, query_size, value_size = 512, 512, 512  # k/q/v的大小
num_hiddens, num_layers, dropout= 512, 2, 0.1
ffn_num_input, ffn_num_hiddens, num_heads = 512, 1024, 16
norm_shape = [512]

encoder = TransformerEncoder(len(source_vocab), key_size, query_size, value_size, num_hiddens,
                             norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout)
decoder = TransformerDecoder(len(target_vocab), key_size, query_size, value_size, num_hiddens,
                             norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout)
net = EncoderDecoder(encoder, decoder)
net.apply(xavier_init_weights);

结构5

In [None]:
key_size, query_size, value_size = 1024, 1024, 1024  # k/q/v的大小
num_hiddens, num_layers, dropout= 1024, 2, 0.1
ffn_num_input, ffn_num_hiddens, num_heads = 1024, 2048, 32
norm_shape = [1024]

encoder = TransformerEncoder(len(source_vocab), key_size, query_size, value_size, num_hiddens,
                             norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout)
decoder = TransformerDecoder(len(target_vocab), key_size, query_size, value_size, num_hiddens,
                             norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout)
net = EncoderDecoder(encoder, decoder)
net.apply(xavier_init_weights);

### 训练与保存参数

保存和加载参数

In [22]:
device = try_gpu()
vocab = torch.load('transformer参数文件夹/p2pro.vocab')
source_vocab, target_vocab = vocab[0], vocab[1]

net.load_state_dict(torch.load('transformer参数文件夹/p2pro.params'))

net = torch.load('transformer参数文件夹/p2pro.model')
net.to(device);

FileNotFoundError: [Errno 2] No such file or directory: 'transformer参数文件夹/p2pro.vocab'

训练

In [23]:
device = try_gpu()
lr, num_epochs  = 0.005, 200
train_transformer(net, train_iter, lr, num_epochs, target_vocab, device)

Epoch: 1, loss: 0.287
Epoch: 5, loss: 0.175
Epoch: 10, loss: 0.146
Epoch: 15, loss: 0.127
Epoch: 20, loss: 0.113
Epoch: 25, loss: 0.105
Epoch: 30, loss: 0.099
Epoch: 35, loss: 0.095
Epoch: 40, loss: 0.091
Epoch: 45, loss: 0.089
Epoch: 50, loss: 0.088
Epoch: 55, loss: 0.087
Epoch: 60, loss: 0.085
Epoch: 65, loss: 0.084
Epoch: 70, loss: 0.083
Epoch: 75, loss: 0.083
Epoch: 80, loss: 0.082
Epoch: 85, loss: 0.082
Epoch: 90, loss: 0.081
Epoch: 95, loss: 0.081
Epoch: 100, loss: 0.080
Epoch: 105, loss: 0.080
Epoch: 110, loss: 0.080
Epoch: 115, loss: 0.079
Epoch: 120, loss: 0.079
Epoch: 125, loss: 0.079
Epoch: 130, loss: 0.078
Epoch: 135, loss: 0.078
Epoch: 140, loss: 0.079
Epoch: 145, loss: 0.078
Epoch: 150, loss: 0.078
Epoch: 155, loss: 0.078
Epoch: 160, loss: 0.078
Epoch: 165, loss: 0.077
Epoch: 170, loss: 0.077
Epoch: 175, loss: 0.077
Epoch: 180, loss: 0.077
Epoch: 185, loss: 0.077
Epoch: 190, loss: 0.077
Epoch: 195, loss: 0.077
Epoch: 200, loss: 0.077
final loss 0.077


保存参数和加载参数

In [26]:
torch.save(vocab, 'p2.vocab')
torch.save(net, 'p2.model')
torch.save(net.state_dict(), 'p2.params')

# 预测

In [30]:
def truncate_pad(line, num_steps, padding_token):
    """截断或填充文本序列"""
    if len(line) > num_steps:
        return line[:num_steps]  # 截断
    return line + [padding_token] * (num_steps - len(line))  # 填充

In [32]:
def predict(net, source_sentence, source_vocab, target_vocab, num_steps=15, device=try_gpu(), save_attention_weights=False):
    """序列到序列模型的预测"""
    net.eval()  # 在预测时将net设置为评估模式
    source_tokens = [source_vocab[i] for i in re.compile(r'\w+|\,|\.|\?').findall(source_sentence.lower())] + [source_vocab['<eos>']]
    encoder_valid_len = torch.tensor([len(source_tokens)], device=device)
    source_tokens = truncate_pad(source_tokens, num_steps, source_vocab['<pad>'])
    
    encoder_X       = torch.unsqueeze(torch.tensor(source_tokens, dtype=torch.long, device=device), dim=0)  # 添加批量轴
    encoder_outputs = net.encoder(encoder_X,  encoder_valid_len)
    decoder_state   = net.decoder.init_state(encoder_outputs, encoder_valid_len)
    decoder_X       = torch.unsqueeze(torch.tensor([target_vocab['<bos>']], dtype=torch.long, device=device), dim=0)  # 添加批量轴

    def get_key(val):
        for key, value in target_vocab.items():
            if val == value:
                return key
            
    output_seq, attention_weight_seq = [], []
    for i in range(num_steps):
        Y, decoder_state = net.decoder(decoder_X, decoder_state)
        decoder_X = Y.argmax(dim=2)  # 使用具有预测最高可能性的词元，作为解码器在下一时间步的输入
        pred = decoder_X.squeeze(dim=0).type(torch.int32).item()   
        if save_attention_weights:  # 保存注意力权重
            attention_weight_seq.append(net.decoder.attention_weights)
        if pred == target_vocab['<eos>']:  # 预测出eos, 预测结束
            break
        output_seq.append(pred)
        output_sentence = [get_key(i) for i in output_seq]
    return ''.join(output_sentence), attention_weight_seq

In [33]:
def bleu(pred_seq, label_seq, k): 
    """计算BLEU"""
    # 指数
    pred_tokens, label_tokens = list(pred_seq), list(label_seq)
    len_pred, len_label = len(pred_tokens), len(label_tokens)
    score = math.exp(min(0, 1 - len_label / len_pred))

    for n in range(1, k + 1):  
        num_matches, label_subs = 0, collections.defaultdict(int)
        for i in range(len_label - n + 1):
            label_subs[' '.join(label_tokens[i: i + n])] += 1
        for i in range(len_pred - n + 1):
            if label_subs[' '.join(pred_tokens[i: i + n])] > 0:
                num_matches += 1
                label_subs[' '.join(pred_tokens[i: i + n])] -= 1
        score *= math.pow(num_matches / (len_pred - n + 1), math.pow(0.5, n))
    return score

In [34]:
def translate(inputs, outputs, num_steps=15, save_weight = False):
    for input, output in zip(inputs, outputs):
        translation, attention_weight_seq = predict(net, input, source_vocab, target_vocab, num_steps, device=try_gpu(), save_attention_weights=save_weight)
        print("-"*25)
        print(f'English: {input}')
        print(f'Chinese: {output}')
        print(f'Predict: {translation}')
        print(f'BLEU:    {bleu(translation, output, k=2):.3f}')
    return attention_weight_seq

In [38]:
engs = ['i will dream it possible.', 
        #"college students can benefit a lot.", 
        #'Another issue is environmental degradation.', 
        #'urbanization can certainly have both positive and negative effects.',
        'that is very important.',
        "Students pay more attention to academic achievement.", 
        'we should participate in more activities.', 
        'Love is burning fire.',]
chas = ['我会梦想成真。', 
        #'大学生能够获益很多。', 
        #'另一个问题是环境恶化。', 
        #'城市化肯定会带来积极的和消极的影响。',
        '那非常重要。',
        '学生们更关注学术成就。',
        '我们应该参加更多活动。',
        '爱情是燃烧的火焰。']
attention_weights = translate(engs, chas, num_steps=30, save_weight=True)

-------------------------
English: i will dream it possible.
Chinese: 我会梦想成真。
Predict: 我会梦想要了。
BLEU:    0.711
-------------------------
English: that is very important.
Chinese: 那非常重要。
Predict: 那是非常重要。
BLEU:    0.837
-------------------------
English: Students pay more attention to academic achievement.
Chinese: 学生们更关注学术成就。
Predict: 学生更注意说时。
BLEU:    0.334
-------------------------
English: we should participate in more activities.
Chinese: 我们应该参加更多活动。
Predict: 我们应该参加社会参加社。
BLEU:    0.627
-------------------------
English: Love is burning fire.
Chinese: 爱情是燃烧的火焰。
Predict: 由木烧着火。
BLEU:    0.000


### 绘制注意力机制热力图

In [None]:
from matplotlib import pyplot as plt
import matplotlib_inline

def use_svg_display():
    """使用svg矢量图格式在jupyter中显示绘图"""
    matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(2.5, 2.5),  cmap='Reds'):
    """显示矩阵热图"""
    use_svg_display()
    num_rows, num_cols = matrices.shape[0], matrices.shape[1]
    fig, axes = plt.subplots(num_rows, num_cols, figsize=figsize, sharex=True, sharey=True, squeeze=False)
    for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):
        for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):
            pcm = ax.imshow(matrix.detach().numpy(), cmap=cmap)
            if i == num_rows - 1:
                ax.set_xlabel(xlabel)
            if j == 0:
                ax.set_ylabel(ylabel)
            if titles:
                ax.set_title(titles[j])
    fig.colorbar(pcm, ax=axes, shrink=0.6);

# 和写

In [None]:
import torch
from torch import nn

import numpy as np 


import re
import math
import collections
from matplotlib import pyplot as plt


def get_dict(filtered_pairs):
    "用已知的‘英文-中文’生成字典"
    en_vocab={}
    cn_vocab={}
    en_vocab['<pad>'],en_vocab['<bos>'],en_vocab['<eos>']=0,1,2
    cn_vocab['<pad>'],cn_vocab['<bos>'],cn_vocab['<eos>']=0,1,2
    en_idx, cn_idx = 3, 3
    for en, cn in filtered_pairs:
        for w in en:
            if w not in en_vocab:
                en_vocab[w]=en_idx
                en_idx+=1
        for w in cn:
            if w not in cn_vocab:
                cn_vocab[w]=cn_idx
                cn_idx+=1
    print(f'The size of english vocab is {len(en_vocab)}.')
    print(f'The size of chinese vocab is {len(cn_vocab)}.')
    return en_vocab, cn_vocab


def load_data_translation(file_name='cmn2.txt', batch_size=2, steps = 15, is_train = True, file_range=[0,1]):
    # 打开数据
    print("Loading data...")
    lines = open(file_name,encoding='utf-8').read().strip().split('\n')
    begin = int(len(lines)*file_range[0])
    end   = int(len(lines)*file_range[1])
    
    # 从txt文件中读取出数据，并成对地保存在pairs中
    words_re=re.compile(r'[a-zA-Z]+|\,|\.|\?|^[0-9]*$')
    pairs, filtered_pairs = [], []
    for l in lines:
        en_sentence, cn_sentence = l.split('\t',1)
        pairs.append((words_re.findall(en_sentence.lower()), list(cn_sentence)))

    # 筛选出来常用的和简短的词语
    MAX_LEN = steps
    for x in pairs:
        if len(x[0])<MAX_LEN and len(x[1])<MAX_LEN:
            filtered_pairs.append(x) 
    print(f'Total sentence number: {len(filtered_pairs)}') 
    print(f'Total selected sentence number: {end - begin}') 

    # 建立字典
    en_vocab, cn_vocab = get_dict(filtered_pairs)
    vocab = [en_vocab, cn_vocab]

    print("Processing data...")
    # 将句子中的单词/字编码
    padded_en_sents, padded_cn_sents, padded_cn_label_sents=[], [], []
    for en,cn in filtered_pairs[begin:end]:
        padded_en_sent=en+["<eos>"]+["<pad>"]*(MAX_LEN-len(en))
        padded_cn_sent=["<bos>"]+cn+["<eos>"]+["<pad>"]*(MAX_LEN-len(cn))
        padded_cn_label_sent=cn+['<eos>']+['<pad>']*(MAX_LEN-len(cn)+1)

        padded_en_sents.append([en_vocab[w] for w in padded_en_sent])
        padded_cn_sents.append([cn_vocab[w] for w in padded_cn_sent])
        padded_cn_label_sents.append([cn_vocab[w] for w in padded_cn_label_sent])

    train_en_sents= torch.tensor(np.array(padded_en_sents))
    train_cn_sents= torch.tensor(np.array(padded_cn_sents))
    train_cn_label_sents=torch.tensor(np.array(padded_cn_label_sents))
    en_valid_len = (train_en_sents != en_vocab['<pad>']).type(torch.int32).sum(1)
    cn_valid_len = (train_cn_sents != cn_vocab['<pad>']).type(torch.int32).sum(1)

    # 转化成dataloader的形式
    data_arrays = [train_en_sents, train_cn_sents, train_cn_label_sents, en_valid_len, cn_valid_len]
    dataset = torch.utils.data.TensorDataset(*data_arrays)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size, shuffle=is_train)
    print('Successfully finished!')

    return dataloader, vocab

def try_gpu():
    """如果存在则返回gpu, 否则返回cpu"""
    if torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')


class Encoder(nn.Module):
    """编码器-解码器架构的基本编码器接口"""
    def __init__(self, **kwargs):
        super(Encoder, self).__init__(**kwargs)

    def forward(self, X, *args):
        raise NotImplementedError


class Decoder(nn.Module):
    """编码器-解码器架构的基本解码器接口"""
    def __init__(self, **kwargs):
        super(Decoder, self).__init__(**kwargs)

    def init_state(self, enc_outputs, *args):
        raise NotImplementedError

    def forward(self, X, state):
        raise NotImplementedError


class EncoderDecoder(nn.Module):
    """编码器-解码器架构的基类"""
    def __init__(self, encoder, decoder, **kwargs):
        super(EncoderDecoder, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, encoder_X, decoder_X, *args):
        enc_outputs = self.encoder(encoder_X, *args)
        dec_state = self.decoder.init_state(enc_outputs, *args)
        return self.decoder(decoder_X, dec_state)
    

def sequence_mask(X, valid_len, value=0):
    """在序列中屏蔽不相关的项"""
    maxlen = X.size(1)
    mask = torch.arange((maxlen), dtype=torch.float32, device=X.device)[None, :] < valid_len[:, None]
    X[~mask] = value
    return X


def masked_softmax(X, valid_lens):  # X:3D张量，valid_lens:1D或2D张量
    """通过在最后一个轴上掩蔽元素来执行softmax操作"""
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
            # torch.repeat_interleave(input, repeats, dim=None) → Tensor
        else:
            valid_lens = valid_lens.reshape(-1)
        X = sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6) # 最后一轴上被掩蔽的元素使用非常大的负值替换，从而其softmax输出为0
        return nn.functional.softmax(X.reshape(shape), dim=-1)
    

class DotProductAttention(nn.Module):
    """缩放点积注意力"""
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    # queries的形状：(batch_size，查询的个数，d)，              keys的形状：(batch_size，“键－值”对的个数，d)
    # values的形状： (batch_size，“键－值”对的个数，值的维度)，  valid_lens的形状:(batch_size，)或者(batch_size，查询的个数)
    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        scores = torch.bmm(queries, keys.transpose(1,2)) / math.sqrt(d)  # 设置transpose_b=True为了交换keys的最后两个维度
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)
    

def transpose_qkv(X, num_heads):
    """为了多注意力头的并行计算而变换形状"""
    # 输入X的形状:(batch_size，查询或者“键－值”对的个数，num_hiddens)
    # 输出X的形状:(batch_size，查询或者“键－值”对的个数，num_heads，num_hiddens/num_heads)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
    # 输出X的形状:(batch_size，num_heads，查询或者“键－值”对的个数, num_hiddens/num_heads)
    X = X.permute(0, 2, 1, 3)
    # 最终输出的形状:(batch_size*num_heads,查询或者“键－值”对的个数, num_hiddens/num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])


def transpose_output(X, num_heads):
    """逆转transpose_qkv函数的操作"""
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)


class MultiHeadAttention(nn.Module):
    """多头注意力"""
    def __init__(self, key_size, query_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # queries，keys，values的形状: (batch_size，查询或者“键－值”对的个数，num_hiddens)
        # 经过变换后，输出的queries，keys，values　的形状: (batch_size*num_heads，查询或者“键－值”对的个数，num_hiddens/num_heads)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)
        if valid_lens is not None:  # valid_lens　的形状: (batch_size，)或(batch_size，查询的个数)
            valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)  # 在轴0，将第一项（标量或者矢量）复制num_heads次，然后如此复制第二项，然后诸如此类。
        output = self.attention(queries, keys, values, valid_lens)  # output的形状:(batch_size*num_heads，查询的个数，num_hiddens/num_heads)
        output_concat = transpose_output(output, self.num_heads)  # output_concat的形状:(batch_size，查询的个数，num_hiddens)
        return self.W_o(output_concat)
    

class PositionWiseFFN(nn.Module):
    """基于位置的前馈网络"""
    def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs, **kwargs):
        super(PositionWiseFFN, self).__init__(**kwargs)
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        # self.relu = nn.ReLU()
        self.relu = nn.LeakyReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

    def forward(self, X):
        return self.dense2(self.relu(self.dense1(X)))
    

class AddNorm(nn.Module):
    """残差连接后进行层规范化"""
    def __init__(self, normalized_shape, dropout, **kwargs):
        super(AddNorm, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        return self.ln(self.dropout(Y) + X)
    


class PositionalEncoding(nn.Module):
    """位置编码"""
    def __init__(self, num_hiddens, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        a = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1)
        b = torch.pow(10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        self.P = torch.zeros((1, max_len, num_hiddens))
        self.P[:, :, 0::2] = torch.sin(a / b)
        self.P[:, :, 1::2] = torch.cos(a / b)

    def forward(self, X):
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)
    


class EncoderBlock(nn.Module):
    """Transformer编码器块"""
    def __init__(self, key_size, query_size, value_size, num_hiddens, 
                 norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, dropout, use_bias=False, **kwargs):
        super(EncoderBlock, self).__init__(**kwargs)
        self.attention  = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout, use_bias)
        self.addnorm1   = AddNorm(norm_shape, dropout)
        self.ffn        = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm2   = AddNorm(norm_shape, dropout)

    def forward(self, X, valid_lens):  # 自注意力+残差连接
        Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
        return self.addnorm2(Y, self.ffn(Y))  # 位置前馈+残差连接
    

class TransformerEncoder(Encoder):
    """Transformer编码器"""
    def __init__(self, vocab_size, key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, 
                 num_heads, num_layers, dropout, use_bias=False, **kwargs):
        super(TransformerEncoder, self).__init__(**kwargs)
        self.num_hiddens    = num_hiddens
        self.embedding      = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding   = PositionalEncoding(num_hiddens, dropout)
        self.blks           = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module("block" + str(i), EncoderBlock(
                key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, dropout, use_bias))

    def forward(self, X, valid_lens, *args):
        embedding = self.embedding(X) * math.sqrt(self.num_hiddens)# 因为位置编码值在-1和1之间，因此嵌入值乘以嵌入维度的平方根进行缩放，然后再与位置编码相加。
        X = self.pos_encoding(embedding) 
        self.attention_weights = [None] * len(self.blks)
        for i, blk in enumerate(self.blks):
            X = blk(X, valid_lens)
            self.attention_weights[i] = blk.attention.attention.attention_weights
        return X
    



class DecoderBlock(nn.Module):
    """解码器中第i个块"""
    def __init__(self, key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, 
                 num_heads,dropout, i, **kwargs):
        super(DecoderBlock, self).__init__(**kwargs)
        self.i = i
        self.attention1 = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm1   = AddNorm(norm_shape, dropout)
        self.attention2 = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm2   = AddNorm(norm_shape, dropout)
        self.ffn        = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm3   = AddNorm(norm_shape, dropout)

    def forward(self, X, state):
        enc_outputs, enc_valid_lens = state[0], state[1]
        if state[2][self.i] is None:  # 训练阶段，输出序列的所有词元都在同一时间处理， 因此state[2][self.i]初始化为None。
            key_values = X
        else:  # 预测阶段，输出序列是通过词元一个接着一个解码的，因此state[2][self.i]包含着直到当前时间步第i个块解码的输出表示。
            key_values = torch.cat((state[2][self.i], X), axis=1)
        state[2][self.i] = key_values
        if self.training:
            batch_size, num_steps, _ = X.shape  # dec_valid_lens的开头:(batch_size,num_steps), 其中每一行是[1,2,...,num_steps]
            dec_valid_lens = torch.arange(1, num_steps + 1, device=X.device).repeat(batch_size, 1)
        else:
            dec_valid_lens = None

        # 自注意力
        X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y  = self.addnorm1(X, X2)
        
        # 编码器－解码器注意力
        Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)  # enc_outputs的开头:(batch_size,num_steps,num_hiddens)
        Z  = self.addnorm2(Y, Y2)
        return self.addnorm3(Z, self.ffn(Z)), state
    


class AttentionDecoder(Decoder):
    """带有注意力机制解码器的基本接口"""
    def __init__(self, **kwargs):
        super(AttentionDecoder, self).__init__(**kwargs)

    def attention_weights(self):
        raise NotImplementedError
    


class TransformerDecoder(AttentionDecoder):
    def __init__(self, vocab_size, key_size, query_size, value_size, num_hiddens, norm_shape, 
                 ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, **kwargs):
        super(TransformerDecoder, self).__init__(**kwargs)
        self.num_hiddens    = num_hiddens
        self.num_layers     = num_layers
        self.embedding      = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding   = PositionalEncoding(num_hiddens, dropout)
        self.blks           = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module("block"+str(i), DecoderBlock(
                key_size, query_size, value_size, num_hiddens,norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, dropout, i))
        self.dense          = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        return [enc_outputs, enc_valid_lens, [None] * self.num_layers]

    def forward(self, X, state):
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        self._attention_weights = [[None] * len(self.blks) for _ in range (2)]
        for i, blk in enumerate(self.blks):
            X, state = blk(X, state)
            self._attention_weights[0][i] = blk.attention1.attention.attention_weights  # decoder self-attentio 权重
            self._attention_weights[1][i] = blk.attention2.attention.attention_weights  # encoder-decoder attention 权重
        return self.dense(X), state

    def attention_weights(self):
        return self._attention_weights
    




def sequence_mask(X, valid_len, value=0):
    """在序列中屏蔽不相关的项"""
    maxlen = X.size(1)
    mask = torch.arange((maxlen), dtype=torch.float32, device=X.device)[None, :] < valid_len[:, None]
    X[~mask] = value
    return X


class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):
    """带遮蔽的softmax交叉熵损失函数"""
    # pred的形状：(batch_size,num_steps,vocab_size)
    # label的形状：(batch_size,num_steps)
    # valid_len的形状：(batch_size,)
    def forward(self, pred, label, valid_len):
        weights = sequence_mask(torch.ones_like(label), valid_len)
        self.reduction='none'
        unweighted_loss = super().forward(pred.permute(0, 2, 1), label)
        weighted_loss = (unweighted_loss * weights).sum(dim=1)
        return weighted_loss


def xavier_init_weights(m):
    if type(m) == nn.Linear:
        nn.init.xavier_uniform_(m.weight)


def grad_clipping(net, theta=1):
    """裁剪梯度"""
    params = [p for p in net.parameters() if p.requires_grad]
    norm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))
    if norm > theta:
        for param in params:
            param.grad[:] *= theta / norm




def train_transformer(net, data_iter, lr, num_epochs, vocab, device=try_gpu(), weight_decay=1e-5):
    """训练transformer模型"""
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)
    loss = MaskedSoftmaxCELoss()
    net.train()
    x_plot, y_plot = [],[]
    for epoch in range(num_epochs):
        print(f'------------------------Epoch {epoch} -------------------------')
        total_loss, total_num = 0.0, 0.0  # 训练损失总和，词元数量
        total_batch = len(data_iter)
        for i_batch, batch in enumerate(data_iter):
            X, Z, Y, X_valid_len, Y_valid_len = [x.to(device) for x in batch]
            decoder_input = Z
            encoder_input = X
            optimizer.zero_grad()
            Y_hat, _ = net(encoder_input, decoder_input, X_valid_len)
            l = loss(Y_hat, Y.type(torch.LongTensor).to(device), Y_valid_len)
            l.sum().backward()
            grad_clipping(net, 1)
            optimizer.step()
            with torch.no_grad():
                total_loss += l.sum()
                total_num  += Y_valid_len.sum()
            if i_batch % 200 == 0:
                print(f'Batch {i_batch}({100*i_batch/total_batch:.2f}%), loss: {total_loss / total_num:.3f}')
                x_plot.append(i_batch)
                y_plot.append(total_loss.to('cpu') / total_num.to('cpu'))
            if i_batch % 1000 == 0:
                plt.plot(x_plot, y_plot)
        if epoch==0 or (epoch + 1) % 5 == 0:
            print(f'Epoch: {epoch + 1}, loss: {total_loss / total_num:.3f}')
    print(f'final loss {total_loss / total_num:.3f}')


def truncate_pad(line, num_steps, padding_token):
    """截断或填充文本序列"""
    if len(line) > num_steps:
        return line[:num_steps]  # 截断
    return line + [padding_token] * (num_steps - len(line))  # 填充


def predict(net, source_sentence, source_vocab, target_vocab, num_steps=15, device=try_gpu(), save_attention_weights=False):
    """序列到序列模型的预测"""
    net.eval()  # 在预测时将net设置为评估模式
    source_tokens = [source_vocab[i] for i in re.compile(r'\w+|\,|\.|\?').findall(source_sentence.lower())] + [source_vocab['<eos>']]
    encoder_valid_len = torch.tensor([len(source_tokens)], device=device)
    source_tokens = truncate_pad(source_tokens, num_steps, source_vocab['<pad>'])
    
    encoder_X       = torch.unsqueeze(torch.tensor(source_tokens, dtype=torch.long, device=device), dim=0)  # 添加批量轴
    encoder_outputs = net.encoder(encoder_X,  encoder_valid_len)
    decoder_state   = net.decoder.init_state(encoder_outputs, encoder_valid_len)
    decoder_X       = torch.unsqueeze(torch.tensor([target_vocab['<bos>']], dtype=torch.long, device=device), dim=0)  # 添加批量轴

    def get_key(val):
        for key, value in target_vocab.items():
            if val == value:
                return key
            
    output_seq, attention_weight_seq = [], []
    for i in range(num_steps):
        Y, decoder_state = net.decoder(decoder_X, decoder_state)
        decoder_X = Y.argmax(dim=2)  # 使用具有预测最高可能性的词元，作为解码器在下一时间步的输入
        pred = decoder_X.squeeze(dim=0).type(torch.int32).item()   
        if save_attention_weights:  # 保存注意力权重
            attention_weight_seq.append(net.decoder.attention_weights)
        if pred == target_vocab['<eos>']:  # 预测出eos, 预测结束
            break
        output_seq.append(pred)
        output_sentence = [get_key(i) for i in output_seq]
    return ''.join(output_sentence), attention_weight_seq


def bleu(pred_seq, label_seq, k): 
    """计算BLEU"""
    # 指数
    pred_tokens, label_tokens = list(pred_seq), list(label_seq)
    len_pred, len_label = len(pred_tokens), len(label_tokens)
    score = math.exp(min(0, 1 - len_label / len_pred))

    for n in range(1, k + 1):  
        num_matches, label_subs = 0, collections.defaultdict(int)
        for i in range(len_label - n + 1):
            label_subs[' '.join(label_tokens[i: i + n])] += 1
        for i in range(len_pred - n + 1):
            if label_subs[' '.join(pred_tokens[i: i + n])] > 0:
                num_matches += 1
                label_subs[' '.join(pred_tokens[i: i + n])] -= 1
        score *= math.pow(num_matches / (len_pred - n + 1), math.pow(0.5, n))
    return score



def translate(inputs, outputs, num_steps=15, save_weight = False):
    for input, output in zip(inputs, outputs):
        translation, attention_weight_seq = predict(net, input, source_vocab, target_vocab, num_steps, device=try_gpu(), save_attention_weights=save_weight)
        print("-"*25)
        print(f'English: {input}')
        print(f'Chinese: {output}')
        print(f'Predict: {translation}')
        print(f'BLEU:    {bleu(translation, output, k=2):.3f}')
    return attention_weight_seq


In [None]:
batch_size = 128
num_steps = 64
a = 0.6
train_iter, vocab = load_data_translation('cmn4.txt', batch_size, num_steps, file_range = [0,a])
source_vocab, target_vocab = vocab[0], vocab[1]

In [None]:
key_size, query_size, value_size = 1024, 1024, 1024  # k/q/v的大小
num_hiddens, num_layers, dropout= 1024, 2, 0.1
ffn_num_input, ffn_num_hiddens, num_heads = 1024, 2048, 32
norm_shape = [1024]

encoder = TransformerEncoder(len(source_vocab), key_size, query_size, value_size, num_hiddens,
                             norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout)
decoder = TransformerDecoder(len(target_vocab), key_size, query_size, value_size, num_hiddens,
                             norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout)
net = EncoderDecoder(encoder, decoder)
net.apply(xavier_init_weights);

In [None]:
device = try_gpu()
lr, num_epochs  = 0.0001, 1
train_transformer(net, train_iter, lr, num_epochs, target_vocab, device, weight_decay = 2e-6)