In [21]:
from seq2seq import train_epoch, get_dataloader, EncoderRNN, AttnDecoderRNN, prepareData
import time
import torch
from torch import optim
import torch.nn as nn
import math
import time

In [31]:
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np

def showPlot(points):
    plt.figure()
    fig, ax = plt.subplots()
    # this locator puts ticks at regular intervals
    loc = ticker.MultipleLocator(base=0.2)
    ax.yaxis.set_major_locator(loc)
    plt.plot(points)

In [3]:
def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))

In [4]:
hidden_size = 128
batch_size = 32

input_lang, output_lang, train_dataloader = get_dataloader(batch_size)

Read 135842 sentence pairs
Trimmed to 11362 sentence pairs
Counting words...
Counted words:
fra 4596
eng 2989


In [5]:
encoder = EncoderRNN(input_lang.n_words, hidden_size)
decoder = AttnDecoderRNN(hidden_size, output_lang.n_words)

In [6]:
def train(train_dataloader, encoder, decoder, n_epochs, learning_rate=0.001,
               print_every=1, plot_every=5):
    start = time.time()
    plot_losses = []
    print_loss_total = 0  # Reset every print_every
    plot_loss_total = 0  # Reset every plot_every

    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)
    criterion = nn.NLLLoss()

    for epoch in range(1, n_epochs + 1):
        loss = train_epoch(train_dataloader, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)
        print_loss_total += loss
        plot_loss_total += loss

        if epoch % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print('%s (%d %d%%) %.4f' % (timeSince(start, epoch / n_epochs),
                                        epoch, epoch / n_epochs * 100, print_loss_avg))

        if epoch % plot_every == 0:
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0

    showPlot(plot_losses)

In [7]:
train(train_dataloader, encoder, decoder, 200)

0m 16s (- 55m 24s) (1 0%) 2.5192
0m 33s (- 55m 8s) (2 1%) 1.6800
0m 48s (- 53m 37s) (3 1%) 1.4019
1m 4s (- 52m 38s) (4 2%) 1.2008
1m 19s (- 51m 39s) (5 2%) 1.0373
1m 35s (- 51m 14s) (6 3%) 0.8991
1m 51s (- 51m 24s) (7 3%) 0.7845
2m 8s (- 51m 29s) (8 4%) 0.6863
2m 23s (- 50m 42s) (9 4%) 0.6013
2m 38s (- 50m 4s) (10 5%) 0.5257
2m 52s (- 49m 32s) (11 5%) 0.4607
3m 7s (- 49m 3s) (12 6%) 0.4064
3m 22s (- 48m 35s) (13 6%) 0.3582
3m 37s (- 48m 13s) (14 7%) 0.3170
3m 52s (- 47m 48s) (15 7%) 0.2822
4m 7s (- 47m 25s) (16 8%) 0.2525
4m 22s (- 47m 1s) (17 8%) 0.2220
4m 37s (- 46m 40s) (18 9%) 0.1990
4m 51s (- 46m 18s) (19 9%) 0.1786
5m 6s (- 45m 59s) (20 10%) 0.1616
5m 21s (- 45m 39s) (21 10%) 0.1458
5m 36s (- 45m 20s) (22 11%) 0.1337
5m 50s (- 45m 0s) (23 11%) 0.1204
6m 5s (- 44m 41s) (24 12%) 0.1134
6m 20s (- 44m 23s) (25 12%) 0.1060
6m 34s (- 44m 3s) (26 13%) 0.0993
6m 49s (- 43m 45s) (27 13%) 0.0957
7m 4s (- 43m 28s) (28 14%) 0.0845
7m 19s (- 43m 11s) (29 14%) 0.0826
7m 34s (- 42m 54s) (30 15%

In [18]:
def evaluate(encoder, decoder, sentence, input_lang, output_lang):
    with torch.no_grad():
        input_tensor = tensorFromSentence(input_lang, sentence)

        encoder_outputs, encoder_hidden = encoder(input_tensor)
        decoder_outputs, decoder_hidden, decoder_attn = decoder(encoder_outputs, encoder_hidden)

        _, topi = decoder_outputs.topk(1)
        decoded_ids = topi.squeeze()

        decoded_words = []
        for idx in decoded_ids:
            if idx.item() == EOS_token:
                decoded_words.append('<EOS>')
                break
            decoded_words.append(output_lang.index2word[idx.item()])
    return decoded_words, decoder_attn

In [28]:
import random

input_lang, output_lang, pairs = prepareData('eng', 'fra', True)
print(random.choice(pairs))

SOS_token = 0
EOS_token = 1

def indexesFromSentence(lang, sentence):
	return [lang.word2index[word] for word in sentence.split(' ')]

def tensorFromSentence(lang, sentence):
	indexes = indexesFromSentence(lang, sentence)
	indexes.append(EOS_token)
	return torch.tensor(indexes, dtype=torch.long).view(1,-1)

def evaluateRandomly(encoder, decoder, n=10):
    for i in range(n):
        pair = random.choice(pairs)
        print('>', pair[0])
        print('=', pair[1])
        output_words, _ = evaluate(encoder, decoder, pair[0], input_lang, output_lang)
        output_sentence = ' '.join(output_words)
        print('<', output_sentence)
        print('')

Read 135842 sentence pairs
Trimmed to 11362 sentence pairs
Counting words...
Counted words:
fra 4596
eng 2989
['nous sommes amoureuses', 'we re in love']


In [29]:
encoder.eval()
decoder.eval()
evaluateRandomly(encoder, decoder)

> je ne vais pas autoriser ca
= i m not going to allow that
< i m not going to allow that anymore <EOS>

> je suis implique
= i m involved
< i m involved <EOS>

> je cherche un sac pour ma femme
= i m looking for a bag for my wife
< i m looking for a bag for my wife <EOS>

> c est quelqu un de bien
= he s a good person
< he s a good person <EOS>

> je suis heureux que vous l ayez aime
= i m happy you liked it
< i m happy you liked it <EOS>

> vous etes trop tendue
= you re too tense
< you re too tense <EOS>

> tu vas bien
= you re all right
< you re all right i fine <EOS>

> ils ne sont pas si mauvais
= they re not so bad
< they re not so bad in no good <EOS>

> je vais voir mary cet apres midi
= i am seeing mary this afternoon
< i am seeing mary this afternoon <EOS>

> je suis un homme
= i am a man
< i am a man man <EOS>



In [33]:
torch.save(encoder.state_dict(), 'seq2seqEncoder')
torch.save(decoder.state_dict(), 'seq2seqDecoder')

In [32]:
def showAttention(input_sentence, output_words, attentions):
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(attentions.cpu().numpy(), cmap='bone')
    fig.colorbar(cax)

    # Set up axes
    ax.set_xticklabels([''] + input_sentence.split(' ') +
                       ['<EOS>'], rotation=90)
    ax.set_yticklabels([''] + output_words)

    # Show label at every tick
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

    plt.show()


def evaluateAndShowAttention(input_sentence):
    output_words, attentions = evaluate(encoder, decoder, input_sentence, input_lang, output_lang)
    print('input =', input_sentence)
    print('output =', ' '.join(output_words))
    showAttention(input_sentence, output_words, attentions[0, :len(output_words), :])


evaluateAndShowAttention('il n est pas aussi grand que son pere')

evaluateAndShowAttention('je suis trop fatigue pour conduire')

evaluateAndShowAttention('je suis desole si c est une question idiote')

evaluateAndShowAttention('je suis reellement fiere de vous')

input = il n est pas aussi grand que son pere
output = he is not as tall as his father <EOS>
input = je suis trop fatigue pour conduire
output = i m too tired to drive <EOS>
input = je suis desole si c est une question idiote
output = i m sorry if this is a stupid question <EOS>
input = je suis reellement fiere de vous
output = i m really proud of you <EOS>


  ax.set_xticklabels([''] + input_sentence.split(' ') +
  ax.set_yticklabels([''] + output_words)
  plt.show()
  ax.set_xticklabels([''] + input_sentence.split(' ') +
  ax.set_yticklabels([''] + output_words)
  plt.show()
  ax.set_xticklabels([''] + input_sentence.split(' ') +
  ax.set_yticklabels([''] + output_words)
  plt.show()
  ax.set_xticklabels([''] + input_sentence.split(' ') +
  ax.set_yticklabels([''] + output_words)
  plt.show()
