In [73]:
import torch
import unidecode
import numpy as np

In [93]:
lines = open('spa-eng/spa.txt', encoding='utf-8').read().strip().split('\n')
pairs = [[s for s in l.split('\t')] for l in lines]

# convert to lower case and drop all non-alphabetic characters
english, spanish = zip(*pairs)
english = [s.lower().strip() for s in english]
spanish = [s.lower().strip() for s in spanish]
english = [''.join(c for c in s if c.isalpha() or c == ' ') for s in english]
spanish = [''.join(c for c in s if c.isalpha() or c == ' ') for s in spanish]
#convert diacritical to non-diacritical
english = [unidecode.unidecode(s) for s in english]
spanish = [unidecode.unidecode(s) for s in spanish]

# randomly take only 0.1% of the data
np.random.seed(0)
idx = np.random.choice(len(english), size=int(len(english)*0.05), replace=False)
english = [english[i] for i in idx]
spanish = [spanish[i] for i in idx]

In [94]:
len(english)

5948

In [95]:
# split train dev test
train_size = 0.6
dev_size = 0.2
test_size = 0.2
assert train_size + dev_size + test_size == 1
train_size = int(train_size * len(english))
dev_size = int(dev_size * len(english))
test_size = len(english) - train_size - dev_size
# use random indices
indices = np.random.permutation(len(english))
train_indices = indices[:train_size]
dev_indices = indices[train_size:train_size+dev_size]
test_indices = indices[train_size+dev_size:]
train_english = [english[i] for i in train_indices]
train_spanish = [spanish[i] for i in train_indices]
dev_english = [english[i] for i in dev_indices]
dev_spanish = [spanish[i] for i in dev_indices]
test_english = [english[i] for i in test_indices]
test_spanish = [spanish[i] for i in test_indices]

In [96]:
# for all spanish sentences, add an bos and eos token
train_spanish = ['<bos> ' + s + ' <eos>' for s in train_spanish]
dev_spanish = ['<bos> ' + s + ' <eos>' for s in dev_spanish]
test_spanish = ['<bos> ' + s + ' <eos>' for s in test_spanish]

In [97]:
# create vocabulary
english_vocab = set()
spanish_vocab = set()
for s in train_english:
    english_vocab.update(s.split())
for s in train_spanish:
    spanish_vocab.update(s.split())

# convert to dictionary
english_vocab = {w: i for i, w in enumerate(english_vocab)}
spanish_vocab = {w: i for i, w in enumerate(spanish_vocab)}

In [98]:
# create RNN class encoding an english sentence
class RNN_encoder(torch.nn.Module):
    def __init__(self, input_size, hidden_size):
        super(RNN_encoder, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = torch.nn.Embedding(input_size, hidden_size)
        self.rnn = torch.nn.RNN(hidden_size, hidden_size)
        
    def forward(self, input, hidden):
        embedded = self.embedding(input).view(1, 1, -1)
        output, hidden = self.rnn(embedded, hidden)
        return output, hidden
    
    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size)
    
# create RNN class decoding the provided hidden state from the encoder
class RNN_decoder(torch.nn.Module):
    def __init__(self, hidden_size, output_size):
        super(RNN_decoder, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = torch.nn.Embedding(output_size, hidden_size)
        self.rnn = torch.nn.RNN(hidden_size, hidden_size)
        self.out = torch.nn.Linear(hidden_size, output_size)
        self.softmax = torch.nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        embedded = self.embedding(input).view(1, 1, -1)
        output, hidden = self.rnn(embedded, hidden)
        output = self.softmax(self.out(output[0]))
        return output, hidden
    
    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size)
    
# create encoder and decoder
hidden_size = 256
encoder = RNN_encoder(len(english_vocab), hidden_size)
decoder = RNN_decoder(hidden_size, len(spanish_vocab))


In [104]:
# combine parameters of encoder and decoder
encoder_decoder_params = list(encoder.parameters()) + list(decoder.parameters())
optimizer = torch.optim.SGD(encoder_decoder_params, lr=0.01)
loss_function = torch.nn.CrossEntropyLoss()


In [105]:
# train
epochs = 10
for epoch in range(epochs):
    epoch_loss = 0
    for i in range(len(train_english)):
        # encode english sentence
        encoder_hidden = encoder.initHidden()
        for word in train_english[i].split():
            word_tensor = torch.tensor([english_vocab[word]], dtype=torch.long)
            _, encoder_hidden = encoder(word_tensor, encoder_hidden)

        # decode spanish sentence
        decoder_input = torch.tensor([spanish_vocab['<bos>']], dtype=torch.long)
        decoder_hidden = encoder_hidden
        loss = 0
        for word in train_spanish[i].split():
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            loss += loss_function(decoder_output, torch.tensor([spanish_vocab[word]], dtype=torch.long))
            decoder_input = torch.tensor([spanish_vocab[word]], dtype=torch.long)
        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()

        if i%100 == 0:
            #print current loss
            print('Item: {}, Loss: {}'.format(i, loss.item()))

    print('Epoch: {}, Loss: {}'.format(epoch, epoch_loss / len(train_english)))
    

Item: 0, Loss: 66.22224426269531
Item: 100, Loss: 12805.17578125
Item: 200, Loss: 426377.4375
Item: 300, Loss: 402265.96875
Item: 400, Loss: 1201491.25
Item: 500, Loss: 119117.9453125
Item: 600, Loss: 473794.125
Item: 700, Loss: 435449.375
Item: 800, Loss: 509830.96875
Item: 900, Loss: 580272.125
Item: 1000, Loss: 163374.453125
Item: 1100, Loss: 119778.015625
Item: 1200, Loss: 2187359.5
Item: 1300, Loss: 1984573.25
Item: 1400, Loss: 2067119.125
Item: 1500, Loss: 1831124.0
Item: 1600, Loss: 2912815.5
Item: 1700, Loss: 279714.3125
Item: 1800, Loss: 942941.125
Item: 1900, Loss: 186260.34375
Item: 2000, Loss: 1890290.5
Item: 2100, Loss: 1726339.75
Item: 2200, Loss: 1519154.125
Item: 2300, Loss: 3470031.5
Item: 2400, Loss: 1281855.5
Item: 2500, Loss: 173352.546875
Item: 2600, Loss: 424814.625
Item: 2700, Loss: 2201312.25
Item: 2800, Loss: 587630.0
Item: 2900, Loss: 1726781.0
Item: 3000, Loss: 234485.1875
Item: 3100, Loss: 115037.7421875
Item: 3200, Loss: 7382780.0
Item: 3300, Loss: 1424121.

KeyboardInterrupt: 