In [17]:
import torch
from torch import nn
from torch import optim

src, tgt = tokenize_nmt(preprocess_nmt(read_data_nmt()),num_examples=600)
src_vocab = Vocal(src, min_feq=2,
                  reserved_tokens=['<pad>', '<bos>', '<eos>'])
tgt_vocab = Vocal(tgt, min_feq=2,
                  reserved_tokens=['<pad>', '<bos>', '<eos>'])
src_data, src_valid = build_array_nmt(src, src_vocab, 10)
tgt_data, tgt_valid = build_array_nmt(tgt, tgt_vocab, 10)
dataset = torch.utils.data.TensorDataset(src_data, src_valid, tgt_data, tgt_valid)
## 训练数据
train_data = torch.utils.data.DataLoader(dataset=dataset, batch_size=64, shuffle=False)

# Bahdanau 注意力 只需要修改decoder部分

In [18]:
class Encoder(nn.Module):
    def __init__(self, *args, **kwargs):
        super(Encoder, self).__init__()
        pass

    def forward(self, X, *args, **kwargs):
        raise NotImplementedError


class Decoder(nn.Module):
    def __init__(self, *args, **kwargs):
        super(Decoder, self).__init__()
        pass

    def init_state(self, enc_outputs, *args, **kwargs):
        raise NotImplementedError

    def forward(self, X, state, *args, **kwargs):
        raise NotImplementedError


class EncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder, *args, **kwargs):
        super(EncoderDecoder, self).__init__(*args, **kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_x, dec_x, *args, **kwargs):
        enc_outputs = self.encoder(enc_x, *args, **kwargs)
        dec_state = self.decoder.init_state(enc_outputs, *args, **kwargs)
        return self.decoder(dec_x, dec_state, *args, **kwargs)

class Seq2SeqEncoder_hmy(Encoder):
    """输入一个x batch_size*num_steps或num_steps*batch_size 
    输出output batch_size num_steps 
    隐藏状态 numlayers * batch_size * hidden_size"""

    def __init__(self, vocab_size, embed_size, hidden_size,num_layers, dropout=0.1, *args, **kwargs):
        super(Seq2SeqEncoder_hmy, self).__init__(*args, **kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.LSTM(embed_size, hidden_size,num_layers, dropout=dropout, batch_first=True)

    def forward(self, X, *args, **kwargs):
        embed_x = self.embed(X)
        outputs, state = self.rnn(embed_x)
        return outputs, state
    

class AdditiveAttention_hmy(nn.Module):
    """加性注意力"""
    def __init__(self,key_size,query_size,num_hiddens,dropout,**kwargs):
        super(AdditiveAttention_hmy,self).__init__(**kwargs)
        self.W_k = nn.Linear(key_size,num_hiddens,bias=False)
        self.W_q = nn.Linear(query_size,num_hiddens,bias=False)
        self.W_v = nn.Linear(num_hiddens,1,bias=False)
        self.dropout = nn.Dropout(dropout)
    # queries 维度应该是 batch_size * 要查询的数量 * q_size向量长度
    # keys 维度是 batch_size * keys的数量（key-value)键值对 * key向量长度
    # values与key相等 value_size可以不一样
    def forward(self, queries, keys, values, valid_lens): ## valid_len从输入来的 屏蔽掉填充部分
        queries,keys=self.W_q(queries),self.W_k(keys)
        queries=queries.unsqueeze(2)
        keys=keys.unsqueeze(1)
        features = queries + keys
        features = torch.tanh(features)
        scores = self.W_v(features).squeeze(-1)
        self.attention_weights = masked_softmax(scores, valid_lens)
        attention_temp=self.dropout(self.attention_weights)
        return torch.bmm(attention_temp, values)
    
    
class AttentionDecoder(Decoder):
    """带有注意力机制解码器的基本接口"""
    def __init__(self,**kwargs):
        super(AttentionDecoder, self).__init__(**kwargs)
    @property
    def attention_weight(self):
        raise NotImplementedError

In [25]:
class Seq2SeqAttentionDecoder(AttentionDecoder):
    """通过encoder传来的原始序列的编码信息 进行解码翻译"""
    def __init__(self,vocab_size,embed_size,num_hiddens,num_layers,dropout=0,**kwargs):
        super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)
        self.attention=AdditiveAttention_hmy(num_hiddens,num_hiddens,num_hiddens,dropout=dropout)
        
        self.embedding=nn.Embedding(vocab_size,embed_size)
        self.rnn=nn.LSTM(embed_size+num_hiddens,num_hiddens,num_layers,batch_first=True,dropout=dropout)
        self.dense=nn.Linear(num_hiddens,vocab_size)

    def init_state(self,encoder_out,enc_valid_lens,*args):
        outputs,hidden_state=encoder_out
        return outputs,hidden_state,enc_valid_lens

    def forward(self,X,state,*args,**kwargs):
        #ouputs batch_size,num_steps,num_hiddens
        enc_outputs, hidden_state, enc_valid_lens = state
        (H,C)=hidden_state
        X=self.embedding(X).permute(1,0,2)
        dec_outputs, self._attention_weights = [], []
        for x in X:
            #，unsqueeze(1) 这一步的目的就是为了给 Decoder 的顶层隐藏状态显式地添加一个维度，用来表示 Query 的数量 (在这个时间步是 1)，
            # 从而使其形状符合 Attention 模块期望的 (batch_size, num_queries, feature_size) 输入格式，使得 Attention 模块可以正确地进行批处理和内部计算。
            # query的形状为(batch_size,1,num_hiddens)
            query=H[-1].unsqueeze(1)
            # qkv 和q的有效长度
            # query batch_size,1,num_hiddens
            # context batch_size,num_steps,num_hiddens
            context=self.attention(query,enc_outputs,enc_outputs,enc_valid_lens)
            # x为batch_size,1,embed+hidden_size
            x=torch.cat((context,x.unsqueeze(1)),dim=-1)
            out,hidden_state=self.rnn(x,hidden_state)
            dec_outputs.append(out)
            self._attention_weights.append(self.attention.attention_weights)
        dec_outputs = self.dense(torch.cat(dec_outputs, dim=1))
        return dec_outputs, [enc_outputs, hidden_state,
                                          enc_valid_lens]
    @property
    def attention_weights(self):
        return self._attention_weights

In [26]:
encoder = Seq2SeqEncoder_hmy(vocab_size=10, embed_size=8, hidden_size=16,
                             num_layers=2)

decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16,
                                  num_layers=2)

In [27]:
X = torch.zeros((4, 7), dtype=torch.long)  # (batch_size,num_steps)
state = decoder.init_state(encoder(X), None)
output, state = decoder(X, state)
output.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape

(torch.Size([4, 7, 10]), 3, torch.Size([4, 7, 16]), 2, torch.Size([2, 4, 16]))

In [28]:
def sequence_mask(X, valid_len, value=0):
    """在序列中屏蔽不相关的项"""
    # 首先x是二维的 最内层维度是句子长度 注意：是训练集所以才知道句子真实长度
    # 拿出总的长度 得到长度
    maxlen = X.shape[1]
    # 然后用总长度生成一个1维的向量 使用函数扩展成2维以便与valid_len进行广播
    mask = torch.unsqueeze(torch.arange(0, maxlen, dtype=torch.long), dim=0)
    # mask在0维度扩充 valid在1维度扩充 因为每一个valid对应的是每一个x valid的数字其实是x的第二维向量
    mask = (mask < torch.unsqueeze(valid_len, dim=1))  # 这里小于号就够了 因为<eos>所在位置的索引其实是valid_len-1
    X[~mask] = value
    return X


# 拓展的softmax因为对填充值进行softmax其实没有什么意义
class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):
    def forward(self, pred, label, valid_len):
        weights = torch.ones_like(label)
        weights = sequence_mask(weights, valid_len)
        self.reduction = 'none'
        pred = pred.permute(0, 2, 1)
        # 交叉熵损失期望的两个输入 x是 batch_size vocab_size seq_lenth 
        # y 是 batch_size seq_len
        unweight_loss = super().forward(pred, label)
        weights_loss = unweight_loss * weights
        return weights_loss.mean(dim=1)

loss = MaskedSoftmaxCELoss()

In [29]:
embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs = 0.005, 250

encoder=Seq2SeqEncoder_hmy(vocab_size=len(src_vocab), embed_size=32, hidden_size=32,num_layers=2)
decoder=Seq2SeqAttentionDecoder(vocab_size=len(tgt_vocab), embed_size=32, num_hiddens=32,num_layers=2)
net=EncoderDecoder(encoder, decoder)

In [30]:
# train
optimizer = optim.Adam(net.parameters(), lr=0.005)
for epoch in range(300):
    for batch in train_data:
        optimizer.zero_grad()
        
        src,src_valid,tgt,tgt_valid=batch
        Y=torch.cat((torch.tensor([tgt_vocab['<bos>']]).unsqueeze(0).repeat(tgt.shape[0],1),tgt),dim=1)[:,:-1]
        # 易错点1：训练的时候应该使用带有bos的Y 表示强制教学 此时输出的y_hat实际上是没有bos的
        y_hat,_=net(src,Y,src_valid)
        # 点2 因为输出的y_hat 没有bos 因此在计算loss的时候应该使用原始序列作为目标序列
        l=loss(y_hat,tgt,tgt_valid).sum()
        l.backward()
        grad_clipping(net, 1)
        optimizer.step()
    print(l)

tensor(53.0358, grad_fn=<SumBackward0>)
tensor(43.2106, grad_fn=<SumBackward0>)
tensor(41.9331, grad_fn=<SumBackward0>)
tensor(40.8712, grad_fn=<SumBackward0>)
tensor(41.0336, grad_fn=<SumBackward0>)
tensor(40.4952, grad_fn=<SumBackward0>)
tensor(39.7365, grad_fn=<SumBackward0>)
tensor(38.7306, grad_fn=<SumBackward0>)
tensor(37.9568, grad_fn=<SumBackward0>)
tensor(36.6157, grad_fn=<SumBackward0>)
tensor(36.1713, grad_fn=<SumBackward0>)
tensor(34.9458, grad_fn=<SumBackward0>)
tensor(33.9696, grad_fn=<SumBackward0>)
tensor(32.3891, grad_fn=<SumBackward0>)
tensor(30.8281, grad_fn=<SumBackward0>)
tensor(29.7276, grad_fn=<SumBackward0>)
tensor(28.6182, grad_fn=<SumBackward0>)
tensor(28.6453, grad_fn=<SumBackward0>)
tensor(27.3400, grad_fn=<SumBackward0>)
tensor(26.8034, grad_fn=<SumBackward0>)
tensor(26.1405, grad_fn=<SumBackward0>)
tensor(25.5864, grad_fn=<SumBackward0>)
tensor(24.7411, grad_fn=<SumBackward0>)
tensor(24.0848, grad_fn=<SumBackward0>)
tensor(23.7306, grad_fn=<SumBackward0>)


In [31]:
# predict
engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for src_sentence,tgt_sentence in zip(engs,fras):
    src_tokens=[src_vocab[i] for i in src_sentence.split(' ')]+[src_vocab['<eos>']]
    src_data=truncate_pad(src_tokens,10,1)
    # 易错点 srcdata在生成完成后需要unsqueeze 因为原本是1维的需要增加一个批次维度
    src_data=torch.tensor(src_data).unsqueeze(0)
    # 易错点 需要加上这个批次的有效长度建议1维 或者无维度
    enc_valid_len=torch.tensor([len(src_tokens)])
    
    enc_outputs=net.encoder(src_data,enc_valid_len)
    state=net.decoder.init_state(enc_outputs,enc_valid_len)
    # 易错点 这里必须得是long 因为embed层需要long输入 并且要unsequeeze 增加一个维度
    dec_x=torch.tensor([tgt_vocab['<bos>']],dtype=torch.long).unsqueeze(0)

    output_list=[]
    for i in range(10):
        output,state=net.decoder(dec_x,state)
        # 易错点 需要用这一次的输出argmax之后 作为下一次的输入 因为输入必须得是long
        dec_x=torch.argmax(output,dim=-1)

        output_list.append(dec_x.squeeze(0).squeeze(0))
        if tgt_vocab['<eos>']==torch.argmax(output,dim=-1).squeeze(0):
            break
    print([tgt_vocab.idx_to_token[i] for i in output_list])

['va', '!', '<eos>']
["j'ai", 'perdu', '.', '<eos>']
['il', 'est', 'bon', '.', '<eos>']
['je', 'suis', 'chez', 'moi', '.', '<eos>']
