In [106]:
import os
import torch
import requests
import matplotlib.pyplot as plt
from torch import nn
from utils.useful_func import *

In [71]:
def read_data_nmt():
    with open(r'../data/fra-eng/fra.txt','r',encoding='utf-8') as f:
        result=f.read()
    return result

In [2]:
#@save
def preprocess_nmt(text):
    """预处理“英语－法语”数据集"""
    def no_space(char, prev_char):
        return char in set(',.!?') and prev_char != ' '

    # 使用空格替换不间断空格
    # 使用小写字母替换大写字母
    text = text.replace('\u202f', ' ').replace('\xa0', ' ').lower()
    # 在单词和标点符号之间插入空格
    out = [' ' + char if i > 0 and no_space(char, text[i - 1]) else char
           for i, char in enumerate(text)]
    return ''.join(out)

text = preprocess_nmt(result)

In [77]:
def tokenize_nmt(text,num_examples=None):
    source,target=[],[]
    for i,line in enumerate(text.split('\n')):
        if num_examples and i>=num_examples:
            break
        parts=line.split('\t')
        ## 必须得有原始和目标值所以长度得为2
        if len(parts)==2:
            source.append(parts[0].split(' '))
            target.append(parts[1].split(' '))
    return source,target

In [78]:
#@save
def truncate_pad(line, num_steps, padding_token):
    """截断或填充文本序列"""
    if len(line) >= num_steps:
        return line[:num_steps]
    return line+[padding_token]*(num_steps-len(line))

In [79]:
## 将文本数据作为批量数据，并且加入<eos>至末尾 同时统计有效字符 包含eos
def build_array_nmt(lines, vocab, num_steps):
    lines=[src_vocab[i] for i in lines]
    lines=[i+[vocab['<eos>']] for i in lines]
    array=torch.tensor([truncate_pad(i,num_steps,vocab['<pad>']) for i in lines])
    valid_len =(array != vocab['<pad>']).type(torch.int32).sum(dim=1)
    return array, valid_len

In [103]:
def load_data_nmt(batch_size, num_steps, num_examples=600):
    ## 结果是原始列表 列表表长度 目标列表 目标列表长度
    source, target = tokenize_nmt(text, num_examples)
    src_vocab = Vocal(source, min_feq=2,
                              reserved_tokens=['<pad>', '<bos>', '<eos>'])
    tgt_vocab = Vocal(target, min_feq=2,
                              reserved_tokens=['<pad>', '<bos>', '<eos>'])
    
    src_array, src_valid_len = build_array_nmt(source, src_vocab, num_steps)
    tgt_array, tgt_valid_len = build_array_nmt(target, tgt_vocab, num_steps)
    data_arrays = (src_array, src_valid_len, tgt_array, tgt_valid_len)
    dataset = torch.utils.data.TensorDataset(*data_arrays)
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size,shuffle=True),src_vocab, tgt_vocab

In [104]:
train_iter, src_vocab, tgt_vocab = load_data_nmt(batch_size=2, num_steps=8)

In [105]:
for X, X_valid_len, Y, Y_valid_len in train_iter:
    print('X:', X.type(torch.int32))
    print('X的有效长度:', X_valid_len)
    print('Y:', Y.type(torch.int32))
    print('Y的有效长度:', Y_valid_len)
    break

X: tensor([[2944,    4,    3,    1,    1,    1,    1,    1],
        [  12,  189,    4,    3,    1,    1,    1,    1]], dtype=torch.int32)
X的有效长度: tensor([3, 4])
Y: tensor([[  0, 126,   3,   1,   1,   1,   1,   1],
        [ 12,   0,   0,   4,   3,   1,   1,   1]], dtype=torch.int32)
Y的有效长度: tensor([3, 5])


In [138]:
class Decoder(nn.Module):
    """编码器-解码器架构的基本解码器接口"""
    def __init__(self,**kwargs):
        super().__init__(**kwargs)
    
    def init_state(self,enc_outputs,*args):
        raise NotImplementedError
    
    def forward(self,X,state):
        raise NotImplementedError

In [139]:
class Encoder(nn.Module):
    """编码器—解码器架构的基本编码器接口"""
    def __init__(self, **kwargs):
        super(Encoder, self).__init__(**kwargs)
    def forward(self, X,*args):
        raise NotImplementedError

class EncoderDecoder(nn.Module):
    def __init__(self,encoder,decoder,**kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
    def forward(self, enc_X, dec_X, *args):
        enc_outputs = self.encoder(enc_X)
        dec_state=self.decoder.init(enc_outputs,*args)
        return self.decoder(dec_X,dec_state)
        

In [140]:
import collections
import math
import torch
from torch import nn
from d2l import torch as d2l

In [194]:
class Seq2SeqEncoder(Encoder):
    def __init__(self,vocab_size,embed_size,num_hiddens,num_layers,dropout=0,**kwargs):
        super().__init__(**kwargs)
        # 嵌入层
        self.embedding = nn.Embedding(vocab_size,embed_size)
        self.rnn = nn.GRU(embed_size,num_hiddens,num_layers,dropout=dropout)
        
    def forward(self, X, *args):
        # 输出'X'的形状：(batch_size,num_steps,embed_size)
        X = self.embedding(X)
        # 在循环神经网络模型中，第一个轴对应于时间步
        X = X.permute(1, 0, 2)
        ## 无初始状态则初始状态为0
        output, state = self.rnn(X)
        return output, state
        

class Seq2SeqDecoder(Decoder):
    def __init__(self,vocab_size,embed_size,num_hiddens,num_layers,dropout=0,**kwargs):
        super().__init__(**kwargs)
        self.embedding = nn.Embedding(vocab_size,embed_size)
        self.rnn=nn.GRU(embed_size + num_hiddens,num_hiddens,num_layers,dropout=dropout)
        self.dense=nn.Linear(num_hiddens,vocab_size)
        
    def init_state(self,enc_outputs,*args):
        return enc_outputs[1]
    
    def forward(self,X,dec_state):
        X=self.embedding(X).permute(1, 0, 2)
        # 广播context，使state最后时刻最后一层 具有与X相同的num_steps
        context=dec_state[-1].repeat(X.shape[0],1,1)
        X_and_context=torch.cat((X,context),dim=-1)
        output,state=self.rnn(X_and_context,dec_state)
        # output的形状:(batch_size,num_steps,vocab_size)
        output=self.dense(output).permute(1,0,2)
        # state的形状:(num_layers,batch_size,num_hiddens)
        return output,state
        

In [195]:
encoder=Seq2SeqEncoder(10,8,16,2)
encoder.eval()
X=torch.zeros((4,7),dtype=torch.long)
output,state=encoder(X)
## 7是时间步 4是batch_size, 16是隐藏层维度
output.shape
## 层数 batch_size 隐藏层维度
state.shape

torch.Size([2, 4, 16])

In [196]:
decoder = Seq2SeqDecoder(vocab_size=10, embed_size=8, num_hiddens=16,
                         num_layers=2)

In [197]:
dec_state=decoder.init_state(encoder(X))

In [199]:
output,state=decoder.forward(X,dec_state)

In [201]:
output.shape

torch.Size([4, 7, 10])

In [202]:
state.shape

torch.Size([2, 4, 16])

损失函数

In [None]:
def sequence_mask(X,valid_len,value=0):
    """在序列中屏蔽不相关的项"""
    maxlen = X.shape[0]