In [7]:
import torch
import torch.nn as nn

In [17]:
# Read the tab-separated parallel corpus; only the first two columns of each
# row are kept (any further columns are ignored).
with open("cmn.txt", encoding='utf8') as f:
    # Iterate the file handle directly instead of materialising readlines().
    raw_pairs = [line.split("\t")[:2] for line in f]

# Sentence punctuation (ASCII and full-width). Not used by the whitespace
# tokenisation below; kept as a module-level name in case later cells use it.
sep = ".!?,。？！，"

# Whitespace-tokenise both sides of every pair. Rows with fewer than two
# tab-separated fields (e.g. a blank trailing line) are skipped, because
# indexing their missing second field would raise IndexError.
lines = [
    (pair[0].split(), pair[1].split())
    for pair in raw_pairs
    if len(pair) == 2
]
print(lines[0])

(['Hi.'], ['嗨。'])


In [4]:
class Encoder(nn.Module):
    """Abstract interface for the encoder half of an encoder-decoder model.

    Subclasses must override ``forward``; the base implementation only raises.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Encode ``x``. Must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

class Decoder(nn.Module):
    """Abstract interface for the decoder half of an encoder-decoder model.

    Subclasses must override both ``init_state`` and ``forward``; the base
    implementations only raise.
    """

    def __init__(self):
        super().__init__()

    def init_state(self, state):
        """Turn ``state`` into the decoder's starting state. Must be overridden.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def forward(self, x, state):
        """Decode ``x`` given ``state``. Must be overridden.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

class Seq2Seq(nn.Module):
    """Chain an encoder module and a decoder module into one model.

    Args:
        encoder: Module called as ``encoder(x)``.
        decoder: Module providing ``init_state(encoder_output)`` and being
            callable as ``decoder(y, state)``.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder_ = encoder
        self.decoder_ = decoder

    def forward(self, x, y):
        """Run the encoder on ``x`` and the decoder on ``y``.

        The encoder output is converted by ``decoder.init_state`` into the
        state that is handed to the decoder together with ``y``.

        Returns:
            Whatever the decoder returns for ``(y, state)``.
        """
        initial_state = self.decoder_.init_state(self.encoder_(x))
        return self.decoder_(y, initial_state)

    # TODO: add a predict(x) method for inference (not implemented yet).

In [5]:
class GRUEncoder(Encoder):
    """Encoder that embeds token indices and runs them through a 3-layer GRU.

    Args:
        vocab_size: Number of entries in the embedding table.
        embedding_dim: Size of each embedding vector. The GRU hidden size is
            twice this value (stored as ``self.hidden_dim``).
    """

    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.hidden_dim = embedding_dim * 2
        self.word2vec = nn.Embedding(vocab_size, embedding_dim)
        self.gru = nn.GRU(input_size=embedding_dim, hidden_size=self.hidden_dim, num_layers=3, bias=True, dropout=0.5)

    def forward(self, x):
        """Encode a batch of index sequences and return the final GRU state.

        Args:
            x: Integer tensor of token indices. The GRU is built with the
                default ``batch_first=False``, so a 2-D ``x`` is read as
                ``(seq_len, batch)``.

        Returns:
            The final hidden state of every GRU layer, shaped
            ``(num_layers, batch, hidden_dim)`` for batched input.
        """
        embedding = self.word2vec(x)
        # Do not build h_0 by hand: the previous (x.shape[0], hidden_dim)
        # tensor had the wrong rank/size for a 3-layer GRU and was always
        # created on the CPU. When h_0 is omitted, nn.GRU starts from zeros
        # with the correct shape, dtype and device.
        _, hidden = self.gru(embedding)
        return hidden

class GRUDecoder(Decoder):
    """Decoder built from a 2-layer GRU followed by a linear layer over the vocabulary.

    There is no embedding layer in this class: ``forward`` feeds ``y`` straight
    into the GRU, so ``y`` must already have ``embedding_dim`` features.

    Args:
        vocab_size: Output size of the final linear layer.
        embedding_dim: Feature size of the GRU input. The GRU hidden size is
            twice this value (stored as ``self.hidden_dim``).
    """

    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.hidden_dim = embedding_dim * 2
        self.gru = nn.GRU(input_size=embedding_dim, hidden_size=self.hidden_dim, num_layers=2, bias=True, dropout=0.5)
        self.linear = nn.Linear(self.hidden_dim, vocab_size)

    def init_state(self, state):
        """Adapt an encoder hidden state to this decoder's GRU depth.

        Args:
            state: Hidden-state tensor whose first axis indexes GRU layers.

        Returns:
            The last ``self.gru.num_layers`` layers of ``state``. A state that
            already has exactly that many layers is returned with the same
            values; a deeper one (e.g. from the 3-layer GRUEncoder) is trimmed
            to its top layers so the GRU accepts it.
        """
        return state[-self.gru.num_layers:]

    def forward(self, y, state):
        """Decode ``y`` starting from ``state``.

        Args:
            y: Float tensor fed directly to the GRU; with the default
                ``batch_first=False`` it is read as
                ``(seq_len, batch, embedding_dim)``.
            state: Initial GRU hidden state, ``(num_layers, batch, hidden_dim)``.

        Returns:
            Tensor of shape ``(seq_len, batch, vocab_size)`` with values in
            ``(0, 1)`` (element-wise sigmoid of the linear layer's output).
        """
        output, _ = self.gru(y, state)
        # output.shape = (seq_len, batch, hidden_dim)
        logits = self.linear(output)
        # NOTE(review): sigmoid scores each vocabulary entry independently and
        # the scores do not sum to 1. If a distribution over the vocabulary is
        # intended, softmax (or raw logits for a cross-entropy loss) would be
        # needed -- confirm against the loss used in training.
        return torch.sigmoid(logits)