In [8]:
import torch
import torch.nn as nn

In [24]:
with open("cmn.txt", encoding='utf8') as f:
    lines = f.readlines()

sep = ".!?,。？！，"
lines = ["".join([char if char not in sep else " " + char + " " for char in line]) for line in lines]
print(lines[0])
lines = [line.split("\t")[:2] for line in lines]
lines = [(line_pair[0].split(), line_pair[1].split()) for line_pair in lines]

source_sequence = []
target_sequence = []
for line in lines:
    source_sequence.append(line[0])
    target = []
    for seq in line[1]:
        if len(seq) == 1:
            target.append(seq)
        else:
            target.extend([char for char in seq])
    target_sequence.append(target)
print(source_sequence[10], target_sequence[10])

source_dict = {}
target_dict = {}

for sequence in source_sequence:
    for word in sequence:
        if source_dict.get(word, -1) == -1:
            source_dict[word] = 0
        else:
            source_dict[word] += 1
for sequence in target_sequence:
    for word in sequence:
        if target_dict.get(word, -1) == -1:
            target_dict[word] = 0
        else:
            target_dict[word] += 1

count = 0
for k in target_dict:
    print(k, target_dict[k])
    count += 1
    if count > 10:
        break

Hi . 	嗨 。 	CC-BY 2 . 0 (France) Attribution: tatoeba . org #538123 (CM) & #891077 (Martha)

['Oh', 'no', '!'] ['不', '会', '吧', '。']
嗨 4
。 23548
你 5830
好 1233
用 437
跑 101
的 8091
住 285
手 311
！ 266
等 232


In [3]:
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, x):
        raise NotImplementedError

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
    def init_state(self, state):
        raise NotImplementedError
    def forward(self, x, state):
        raise NotImplementedError

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder_ = encoder
        self.decoder_ = decoder
    def forward(self, x, y):
        encoder_output = self.encoder_(x)
        state = self.decoder_.init_state(encoder_output)
        decoder_output = self.decoder_(y, state)
        return decoder_output
    # def predict(self, x):
    #     NotImplementedError

In [4]:
class GRUEncoder(Encoder):
    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.hidden_dim = embedding_dim * 2
        self.word2vec = nn.Embedding(vocab_size, embedding_dim)
        self.gru = nn.GRU(input_size=embedding_dim, hidden_size=self.hidden_dim, num_layers=3, bias=True, dropout=0.5)
    def forward(self, x):
        hidden_state = torch.zeros((x.shape[0], self.hidden_dim))
        embedding = self.word2vec(x)
        output, hidden = self.gru(embedding, hidden_state)
        return hidden

class GRUDecoder(Decoder):
    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.hidden_dim = embedding_dim * 2 
        self.gru = nn.GRU(input_size=embedding_dim, hidden_size=self.hidden_dim, num_layers=2, bias=True, dropout=0.5)
        self.linear = nn.Linear(self.hidden_dim, vocab_size)    
    def init_state(self, state):
        return state
    def forward(self, y, state):
        output, hidden = self.gru(y, state)
        # output.shape = sequence, batch, embedding_dim
        one_hot = self.linear(output)
        return torch.sigmoid(one_hot)