In [None]:
import torch
import torch.nn as nn

# Toy parallel corpus: four English sentences and their Chinese translations,
# already tokenised and wrapped in '<b>' (begin) / '<e>' (end) markers.
en = [
    ['<b>', 'i', ' ', 'like', ' ', 'you', '<e>'],
    ['<b>', 'i', ' ', 'hate', ' ', 'you', '<e>'],
    ['<b>', 'i', ' ', 'love', ' ', 'you', '<e>'],
    ['<b>', 'he', ' ', 'like', ' ', 'you', '<e>'],
]
zh = [
    ['<b>', '我', '喜欢', '你', '<e>'],
    ['<b>', '我', '讨厌', '你', '<e>'],
    ['<b>', '我', '爱', '你', '<e>'],
    ['<b>', '他', '喜欢', '你', '<e>'],
]

# Index -> token lists; a token's position in the list is its id.
en_vocab_i2t = ['i', 'like', 'you', 'hate', 'he', ' ', 'love', '<b>', '<e>']
zh_vocab_i2t = ['我', '喜欢', '你', '讨厌', '爱', ' ', '他', '<b>', '<e>']

# Token -> index tables, derived from the lists so the two directions always agree.
en_vocab_t2i = {token: index for index, token in enumerate(en_vocab_i2t)}
zh_vocab_t2i = {token: index for index, token in enumerate(zh_vocab_i2t)}


def process(en, zh):
    """Map tokenised sentences to tensors of vocabulary indices.

    Args:
        en: list of English sentences, each a list of tokens found in
            the module-level ``en_vocab_t2i`` table.
        zh: list of Chinese sentences, each a list of tokens found in
            the module-level ``zh_vocab_t2i`` table.

    Returns:
        A pair ``(en_tensor, zh_tensor)``; each tensor has one row per
        sentence and one column per token.

    Raises:
        KeyError: if a token is missing from its vocabulary table.
    """
    # Look up every token first, then build both tensors.
    en_rows = [list(map(en_vocab_t2i.__getitem__, sentence)) for sentence in en]
    zh_rows = [list(map(zh_vocab_t2i.__getitem__, sentence)) for sentence in zh]
    return torch.tensor(en_rows), torch.tensor(zh_rows)


# Seed PyTorch's global RNG before any model is built so that weight
# initialisation -- and therefore the training curve -- is the same after
# every "Restart Kernel -> Run All".
torch.manual_seed(42)

# Index tensors for the whole corpus (one row per sentence).
en_idx, zh_idx = process(en, zh)

# Model hyperparameters shared by the encoder and the decoder.
num_hiddens = 32   # LSTM hidden size
num_layers = 2     # stacked LSTM layers
embed_size = 4     # token embedding dimension

# Vocabulary sizes, taken from the index -> token lists.
zh_vocab_size = len(zh_vocab_i2t)
en_vocab_size = len(en_vocab_i2t)


class Seq2SeqEncoder(nn.Module):
    """Embedding layer followed by a multi-layer LSTM that encodes the source sentence."""

    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers):
        """Build the embedding table and the LSTM.

        Args:
            vocab_size: number of distinct source tokens.
            embed_size: dimension of each token embedding.
            num_hiddens: LSTM hidden size.
            num_layers: number of stacked LSTM layers.
        """
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.LSTM(embed_size, num_hiddens, num_layers)

    def forward(self, input):
        """Encode a batch of token-id sequences.

        Args:
            input: integer tensor indexed as (batch, steps).

        Returns:
            ``(outputs, last_hiddens, last_cells)`` exactly as produced by
            the LSTM: per-step outputs laid out steps-first, plus the final
            hidden and cell states of every layer.
        """
        # The LSTM was built without batch_first, so it wants (steps, batch, embed).
        embedded = self.embed(input).transpose(0, 1)
        outputs, (hidden_state, cell_state) = self.rnn(embedded)
        return outputs, hidden_state, cell_state


class Seq2seqDecoder(nn.Module):
    """LSTM decoder whose every input step is augmented with a context vector.

    The context is the last layer of the hidden state passed to ``forward``,
    repeated across all time steps. ``vocab_size`` is the size of the
    target-language vocabulary.
    """

    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0):
        """Build the embedding table, the LSTM and the output projection.

        Args:
            vocab_size: number of distinct target tokens (also the width of the logits).
            embed_size: dimension of each token embedding.
            num_hiddens: LSTM hidden size; must match the state handed to ``forward``.
            num_layers: number of stacked LSTM layers.
            dropout: dropout probability forwarded to ``nn.LSTM``.
        """
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        # Each step sees the token embedding concatenated with the context vector.
        rnn_input_size = num_hiddens + embed_size
        self.rnn = nn.LSTM(rnn_input_size, num_hiddens, num_layers, dropout=dropout)
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def forward(self, input, enc_hiddens, enc_cells):
        """Decode a batch of target-side token ids.

        Args:
            input: integer tensor indexed as (batch, steps).
            enc_hiddens: initial hidden state for the LSTM; its last layer
                also serves as the context vector.
            enc_cells: initial cell state for the LSTM.

        Returns:
            ``(logits, last_hiddens, last_cells)`` where ``logits`` is indexed
            as (batch, steps, vocab) and the states are the LSTM's final states.
        """
        # Steps-first layout, as required by an LSTM built without batch_first.
        embedded = self.embed(input).transpose(0, 1)
        num_steps = embedded.shape[0]
        # Copy the top-layer hidden state once per time step.
        context = enc_hiddens[-1].repeat(num_steps, 1, 1)
        rnn_input = torch.concat((embedded, context), 2)
        rnn_out, (last_hiddens, last_cells) = self.rnn(rnn_input, (enc_hiddens, enc_cells))
        # Project to vocabulary scores and return to batch-first layout.
        logits = self.dense(rnn_out).transpose(0, 1)
        return logits, last_hiddens, last_cells


class EncoderDecoder(nn.Module):
    """Base class for the encoder-decoder architecture."""

    def __init__(self, encoder, decoder, **kwargs):
        """Store the two sub-modules; extra keyword arguments go to ``nn.Module``."""
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X):
        """Run the encoder on ``enc_X`` and seed the decoder with its final states.

        The encoder's per-step outputs are discarded; only its hidden and
        cell states are handed to the decoder together with ``dec_X``.

        Returns:
            The decoder's ``(output, hidden, cell)`` triple, unchanged.
        """
        _, hidden, cell = self.encoder(enc_X)
        return self.decoder(dec_X, hidden, cell)


In [None]:
# The encoder reads Chinese token ids; the decoder scores English tokens.
encoder = Seq2SeqEncoder(
    vocab_size=zh_vocab_size,
    embed_size=embed_size,
    num_hiddens=num_hiddens,
    num_layers=num_layers,
)
decoder = Seq2seqDecoder(
    vocab_size=en_vocab_size,
    embed_size=embed_size,
    num_hiddens=num_hiddens,
    num_layers=num_layers,
)
net = EncoderDecoder(encoder=encoder, decoder=decoder)

# Smoke-test one forward pass over the whole corpus and show the logits' shape.
y_hat, _, _ = net(zh_idx, en_idx)
y_hat.shape

In [None]:
# Shape of the English index tensor built by process(): one row per sentence,
# one column per token.
en_idx.shape

In [None]:
# Training configuration.
num_epochs = 300
lr = 1e-3

# Cross-entropy over vocabulary scores, optimised with Adam on every parameter of `net`.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=net.parameters(), lr=lr)


def train():
    """Train ``net`` on the toy corpus with teacher forcing.

    Runs ``num_epochs`` full-batch updates using the module-level ``net``,
    ``optimizer``, ``loss_fn``, ``zh_idx`` and ``en_idx``, printing the loss
    after every update.

    The decoder is fed each target sentence without its final token and is
    scored against the same sentence shifted one position to the left, so the
    output at position t must predict token t+1. Feeding and scoring the
    identical, unshifted sequence would let the model minimise the loss by
    copying its input, which teaches it nothing about generating a sentence
    from a lone '<b>' token.
    """
    dec_input = en_idx[:, :-1]   # '<b>' up to the last real token
    dec_target = en_idx[:, 1:]   # first real token up to '<e>'

    for e in range(num_epochs):
        optimizer.zero_grad()
        y_hat, _, _ = net(zh_idx, dec_input)

        # CrossEntropyLoss wants class scores on dim 1: (batch, vocab, steps).
        loss = loss_fn(y_hat.permute(0, 2, 1), dec_target)
        loss.backward()
        optimizer.step()
        # .item() prints a plain float instead of a tensor repr carrying grad_fn.
        print('epoch:{},loss:{}'.format(e, loss.item()))


# Run the training loop; it prints one 'epoch:...,loss:...' line per epoch.
train()




In [None]:

# Decoder seed: a single token with id 7 ('<b>' in the English vocabulary),
# shaped as one sentence of one step.
tgt = torch.tensor([[7]])
# Chinese sentence to translate, as a one-sentence batch of token ids.
x = torch.tensor([[zh_vocab_t2i[token] for token in ['<b>', '我', '讨厌', '我', '<e>']]])
x, tgt

In [None]:
# Greedy decoding.
#
# Each step re-runs the decoder over everything generated so far, always
# starting from the encoder's final state. That keeps the context vector
# (taken by the decoder from the last layer of the hidden state it is given)
# equal to the encoder's, exactly as during training, and it feeds the
# previously predicted token back in instead of repeating '<b>' forever.
max_steps = 20                      # hard cap: never loop forever if '<e>' is not produced
end_id = en_vocab_t2i['<e>']
res = ""
with torch.no_grad():               # inference only, no autograd graph needed
    enc_x, h, c = encoder(x)
    generated = tgt                 # one row of token ids, starting with '<b>'
    for step in range(max_steps):
        output, dec_h, dec_c = decoder(generated, h, c)
        # Scores at the last position are the prediction for the next token.
        pred = output[:, -1].argmax(dim=1)
        token_id = pred.item()
        res += en_vocab_i2t[token_id]
        if token_id == end_id:
            break
        generated = torch.cat((generated, pred.unsqueeze(1)), dim=1)


In [None]:
# Display the English string accumulated by the decoding loop.
res