In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.useful_func import *


# pack_padded_sequence打包需要mask的方法
# import torch
# from torch.nn.utils.rnn import pack_padded_sequence
# 
# # 输入数据：批次大小 3，序列最大长度 5，词向量维度 2
# padded_sequences = torch.randn(5, 3, 2)  # shape (seq_len, batch_size, input_size)
# lengths = torch.tensor([5, 3, 2], dtype=torch.long)  # 实际长度（已降序排列）
# 
# # 打包序列
# packed = pack_padded_sequence(
#     input=padded_sequences,
#     lengths=lengths,
#     batch_first=False,
#     enforce_sorted=True
# )

In [39]:
# Data-pipeline settings, named once so the source and target pipelines
# cannot drift apart.
NUM_STEPS = 10    # sequence length passed to build_array_nmt (presumably pad/truncate length)
MIN_FREQ = 2      # minimum token frequency passed to Vocal
BATCH_SIZE = 32
# Immutable template; each vocab below receives its own list copy, exactly as
# the two separate literals did before.
RESERVED_TOKENS = ('<pad>', '<bos>', '<eos>')

src, tgt = tokenize_nmt(preprocess_nmt(read_data_nmt()))
src_vocab = Vocal(src, min_feq=MIN_FREQ, reserved_tokens=list(RESERVED_TOKENS))
tgt_vocab = Vocal(tgt, min_feq=MIN_FREQ, reserved_tokens=list(RESERVED_TOKENS))
# presumably (token-index tensor, valid lengths) — verify against build_array_nmt
src_data, src_valid = build_array_nmt(src, src_vocab, NUM_STEPS)
tgt_data, tgt_valid = build_array_nmt(tgt, tgt_vocab, NUM_STEPS)
dataset = torch.utils.data.TensorDataset(src_data, src_valid, tgt_data, tgt_valid)
## Training data loader
train_data = torch.utils.data.DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)

In [38]:
class Encoder(nn.Module):
    """Abstract base class for encoders in the encoder-decoder architecture.

    Any constructor arguments are accepted and then dropped, so subclasses can
    pass their leftover ``*args``/``**kwargs`` up without them reaching
    ``nn.Module``.
    """

    def __init__(self, *args, **kwargs):
        # Intentionally do not forward *args / **kwargs to nn.Module.
        super().__init__()

    def forward(self, X, *args, **kwargs):
        """Encode ``X``. Concrete encoders must override this."""
        raise NotImplementedError


class Decoder(nn.Module):
    """Abstract base class for decoders in the encoder-decoder architecture.

    Constructor arguments are accepted and dropped (they are not passed on to
    ``nn.Module``). Subclasses implement ``init_state`` and ``forward``.
    """

    def __init__(self, *args, **kwargs):
        # Intentionally do not forward *args / **kwargs to nn.Module.
        super().__init__()

    def init_state(self, enc_outputs, *args, **kwargs):
        """Build the initial decoder state from the encoder's outputs."""
        raise NotImplementedError

    def forward(self, X, state, *args, **kwargs):
        """Decode ``X`` given ``state``. Concrete decoders must override this."""
        raise NotImplementedError


class EncoderDecoder(nn.Module):
    """Chain an encoder and a decoder into a single module.

    Extra positional/keyword arguments given to ``forward`` are handed,
    unchanged, to all three stages: the encoder call, ``decoder.init_state``
    and the decoder call.
    """

    def __init__(self, encoder, decoder, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Encoder is registered before decoder (same order as assignment).
        self.encoder, self.decoder = encoder, decoder

    def forward(self, enc_x, dec_x, *args, **kwargs):
        """Encode ``enc_x``, seed the decoder state from it, then decode ``dec_x``.

        Returns whatever the decoder's ``forward`` returns.
        """
        state = self.decoder.init_state(
            self.encoder(enc_x, *args, **kwargs), *args, **kwargs
        )
        return self.decoder(dec_x, state, *args, **kwargs)


In [40]:
class Seq2SeqEncoder(Encoder):
    """GRU encoder for sequence-to-sequence models.

    Input ``X`` holds token indices laid out as ``(batch_size, num_steps)``
    when ``batchfirst=True`` (default), or ``(num_steps, batch_size)`` otherwise.

    ``forward`` returns ``(outputs, state)``:
      * ``outputs``: always batch-first, ``(batch_size, num_steps, hidden_size)``
      * ``state``:   ``(num_layers, batch_size, hidden_size)``
    """

    def __init__(self, vocab_size, embed_size, hidden_size, dropout, *args,
                 num_layers=1, **kwargs):
        """
        Args:
            vocab_size: size of the source vocabulary (rows of the embedding).
            embed_size: dimensionality of each token embedding.
            hidden_size: GRU hidden-state size.
            dropout: dropout probability applied between stacked GRU layers;
                it only has an effect when ``num_layers > 1``.
            num_layers: number of stacked GRU layers (keyword-only, default 1).
                Before this parameter existed, such a kwarg was silently
                discarded by ``Encoder.__init__``.
        """
        super(Seq2SeqEncoder, self).__init__(*args, **kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.embed = nn.Embedding(vocab_size, embed_size)
        # nn.GRU applies dropout only *between* layers: with a single layer a
        # non-zero value does nothing except emit a UserWarning.
        self.rnn = nn.GRU(embed_size, hidden_size, num_layers=num_layers,
                          dropout=dropout if num_layers > 1 else 0.0,
                          batch_first=True)

    def forward(self, X, batchfirst=True, *args, **kwargs):
        """Encode a batch of token indices.

        Args:
            X: integer token-index tensor (layout described in the class docstring).
            batchfirst: True if ``X`` is ``(batch_size, num_steps)``; False if it
                is ``(num_steps, batch_size)``.
            *args, **kwargs: accepted for compatibility with ``EncoderDecoder``
                and otherwise ignored.

        Returns:
            ``(outputs, state)`` from the GRU; ``outputs`` is batch-first.
        """
        if isinstance(batchfirst, torch.Tensor):
            # EncoderDecoder.forward passes its extra positional arguments
            # (e.g. valid lengths) straight to the encoder, so they land in
            # this slot. A tensor is never a layout flag, and `not tensor`
            # raises for multi-element tensors, so keep the default layout.
            batchfirst = True
        embed_x = self.embed(X)
        if not batchfirst:
            # (num_steps, batch, embed) -> (batch, num_steps, embed) for the
            # batch_first GRU.
            embed_x = embed_x.permute(1, 0, 2)

        outputs, state = self.rnn(embed_x)
        return outputs, state


In [None]:
class Seq2SeqDecoder(Decoder):
    def __init__(self, vocab_size, embed_size, hidden_size, dropout, *args, **kwargs):