In [312]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.useful_func import *
from torch import optim

# pack_padded_sequence打包需要mask的方法
# import torch
# from torch.nn.utils.rnn import pack_padded_sequence
# 
# # 输入数据：批次大小 3，序列最大长度 5，词向量维度 2
# padded_sequences = torch.randn(5, 3, 2)  # shape (seq_len, batch_size, input_size)
# lengths = torch.tensor([5, 3, 2], dtype=torch.long)  # 实际长度（已降序排列）
# 
# # 打包序列
# packed = pack_padded_sequence(
#     input=padded_sequences,
#     lengths=lengths,
#     batch_first=False,
#     enforce_sorted=True
# )

In [360]:
src, tgt = tokenize_nmt(preprocess_nmt(read_data_nmt()),num_examples=600)
src_vocab = Vocal(src, min_feq=2,
                  reserved_tokens=['<pad>', '<bos>', '<eos>'])
tgt_vocab = Vocal(tgt, min_feq=2,
                  reserved_tokens=['<pad>', '<bos>', '<eos>'])
src_data, src_valid = build_array_nmt(src, src_vocab, 10)
tgt_data, tgt_valid = build_array_nmt(tgt, tgt_vocab, 10)
dataset = torch.utils.data.TensorDataset(src_data, src_valid, tgt_data, tgt_valid)
## 训练数据
train_data = torch.utils.data.DataLoader(dataset=dataset, batch_size=64, shuffle=False)

class Encoder(nn.Module):
    def __init__(self, *args, **kwargs):
        super(Encoder, self).__init__()
        pass

    def forward(self, X, *args, **kwargs):
        raise NotImplementedError


class Decoder(nn.Module):
    def __init__(self, *args, **kwargs):
        super(Decoder, self).__init__()
        pass

    def init_state(self, enc_outputs, *args, **kwargs):
        raise NotImplementedError

    def forward(self, X, state, *args, **kwargs):
        raise NotImplementedError


class EncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder, *args, **kwargs):
        super(EncoderDecoder, self).__init__(*args, **kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_x, dec_x, *args, **kwargs):
        enc_outputs = self.encoder(enc_x, *args, **kwargs)
        dec_state = self.decoder.init_state(enc_outputs, *args, **kwargs)
        return self.decoder(dec_x, dec_state, *args, **kwargs)

class Seq2SeqEncoder(Encoder):
    """输入一个x batch_size*num_steps或num_steps*batch_size 
    输出output batch_size num_steps 
    隐藏状态 numlayers * batch_size * hidden_size"""

    def __init__(self, vocab_size, embed_size, hidden_size,num_layers, dropout=0.1, *args, **kwargs):
        super(Seq2SeqEncoder, self).__init__(*args, **kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, hidden_size,num_layers, dropout=dropout, batch_first=True)

    def forward(self, X, batchfirst=True, *args, **kwargs):
        embed_x = self.embed(X)

        outputs, state = self.rnn(embed_x)
        return outputs, state

class Seq2SeqDecoder(Decoder):
    def __init__(self, vocab_size, embed_size, hidden_size,num_layers, dropout=0.1, *args, **kwargs):
        super(Seq2SeqDecoder, self).__init__(*args, **kwargs)
        self.embed=nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size+hidden_size, hidden_size,num_layers, dropout=dropout, batch_first=True)
        self.dense=nn.Linear(hidden_size, vocab_size)

    def init_state(self, enc_outputs, *args, **kwargs):
        # outputs 形状为 batch_size,num_steps,embed_size
        # state 形状为 num_layers batch_size hidden_size
        state=enc_outputs[1]
        return state

    def forward(self, dec_x,state, *args, **kwargs):
        # batch_size numteps emed_size
        dec_x = self.embed(dec_x)
        # state 为 batch_size hidden_size
        context = state[-1].unsqueeze(1).repeat(1,dec_x.shape[1],1)
        inputs=torch.cat([dec_x, context], dim=-1)
        outputs, state = self.rnn(inputs,state)
        seq=self.dense(outputs)

        return seq, state


In [404]:
# train
encoder=Seq2SeqEncoder(vocab_size=src_vocab.__len__(),embed_size=32,hidden_size=32,num_layers=2)
decoder=Seq2SeqDecoder(vocab_size=tgt_vocab.__len__(),embed_size=32,hidden_size=32,num_layers=2)
net=EncoderDecoder(encoder,decoder)
src_vocab.__len__(),tgt_vocab.__len__()

# 初始化模型参数
def xavier_init_weights(m):
    if type(m) == nn.Linear:
        nn.init.xavier_uniform_(m.weight)
    if type(m) == nn.GRU:
        for name,param in m.named_parameters():
            if "weight" in name:
                nn.init.xavier_uniform_(param)

net.apply(xavier_init_weights)

def sequence_mask(X, valid_len, value=0):
    """在序列中屏蔽不相关的项"""
    # 首先x是二维的 最内层维度是句子长度 注意：是训练集所以才知道句子真实长度
    # 拿出总的长度 得到长度
    maxlen = X.shape[1]
    # 然后用总长度生成一个1维的向量 使用函数扩展成2维以便与valid_len进行广播
    mask = torch.unsqueeze(torch.arange(0, maxlen, dtype=torch.long), dim=0)
    # mask在0维度扩充 valid在1维度扩充 因为每一个valid对应的是每一个x valid的数字其实是x的第二维向量
    mask = (mask < torch.unsqueeze(valid_len, dim=1))  # 这里小于号就够了 因为<eos>所在位置的索引其实是valid_len-1
    X[~mask] = value
    return X


# 拓展的softmax因为对填充值进行softmax其实没有什么意义
class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):
    def forward(self, pred, label, valid_len):
        weights = torch.ones_like(label)
        weights = sequence_mask(weights, valid_len)
        self.reduction = 'none'
        pred = pred.permute(0, 2, 1)
        # 交叉熵损失期望的两个输入 x是 batch_size vocab_size seq_lenth 
        # y 是 batch_size seq_len
        unweight_loss = super().forward(pred, label)
        weights_loss = unweight_loss * weights
        return weights_loss.mean(dim=1)

loss = MaskedSoftmaxCELoss()



In [405]:
# train
optimizer = optim.Adam(net.parameters(), lr=0.005)
for epoch in range(300):
    for batch in train_data:
        optimizer.zero_grad()
        
        src,src_valid,tgt,tgt_valid=batch
        Y=torch.cat((torch.tensor([tgt_vocab['<bos>']]).unsqueeze(0).repeat(tgt.shape[0],1),tgt),dim=1)[:,:-1]
        # 易错点1：训练的时候应该使用带有bos的Y 表示强制教学 此时输出的y_hat实际上是没有bos的
        y_hat,_=net(src,Y,src_valid)
        # 点2 因为输出的y_hat 没有bos 因此在计算loss的时候应该使用原始序列作为目标序列
        l=loss(y_hat,tgt,tgt_valid).sum()
        l.backward()
        grad_clipping(net, 1)
        optimizer.step()
    print(l)

tensor(52.7925, grad_fn=<SumBackward0>)
tensor(44.9963, grad_fn=<SumBackward0>)
tensor(41.9487, grad_fn=<SumBackward0>)
tensor(40.0511, grad_fn=<SumBackward0>)
tensor(38.0886, grad_fn=<SumBackward0>)
tensor(35.7034, grad_fn=<SumBackward0>)
tensor(34.1993, grad_fn=<SumBackward0>)
tensor(32.3679, grad_fn=<SumBackward0>)
tensor(30.4093, grad_fn=<SumBackward0>)
tensor(28.7496, grad_fn=<SumBackward0>)
tensor(27.1362, grad_fn=<SumBackward0>)
tensor(25.7173, grad_fn=<SumBackward0>)
tensor(24.3811, grad_fn=<SumBackward0>)
tensor(23.1102, grad_fn=<SumBackward0>)
tensor(21.8070, grad_fn=<SumBackward0>)
tensor(20.7296, grad_fn=<SumBackward0>)
tensor(19.9570, grad_fn=<SumBackward0>)
tensor(19.3106, grad_fn=<SumBackward0>)
tensor(18.4051, grad_fn=<SumBackward0>)
tensor(17.4071, grad_fn=<SumBackward0>)
tensor(17.2090, grad_fn=<SumBackward0>)
tensor(16.4874, grad_fn=<SumBackward0>)
tensor(15.6214, grad_fn=<SumBackward0>)
tensor(15.0747, grad_fn=<SumBackward0>)
tensor(14.5523, grad_fn=<SumBackward0>)


In [408]:
# predict
engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for src_sentence,tgt_sentence in zip(engs,fras):
    src_tokens=[src_vocab[i] for i in src_sentence.split(' ')]+[src_vocab['<eos>']]
    src_data=truncate_pad(src_tokens,10,1)
    # 易错点 srcdata在生成完成后需要unsqueeze 因为原本是1维的需要增加一个批次维度
    src_data=torch.tensor(src_data).unsqueeze(0)
    # 易错点 需要加上这个批次的有效长度建议1维 或者无维度
    enc_valid_len=torch.tensor([len(src_tokens)])
    
    enc_outputs=net.encoder(src_data,enc_valid_len)
    state=net.decoder.init_state(enc_outputs,enc_valid_len)
    # 易错点 这里必须得是long 因为embed层需要long输入 并且要unsequeeze 增加一个维度
    dec_x=torch.tensor([tgt_vocab['<bos>']],dtype=torch.long).unsqueeze(0)

    output_list=[]
    for i in range(10):
        output,state=net.decoder(dec_x,state)
        # 易错点 需要用这一次的输出argmax之后 作为下一次的输入 因为输入必须得是long
        dec_x=torch.argmax(output,dim=-1)

        output_list.append(dec_x.squeeze(0).squeeze(0))
        if tgt_vocab['<eos>']==torch.argmax(output,dim=-1).squeeze(0):
            break
    print([tgt_vocab.idx_to_token[i] for i in output_list])

['va', 'au', 'unk', '!', '<eos>']
["j'ai", 'unk', '.', '<eos>']
['il', 'est', 'riche', 'unk', '.', '<eos>']
['je', 'suis', 'chez', 'bonne', 'aboient', '?', '?', 'aboient', 'aboient', 'aboient']
