In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [67]:
import numpy as np
import pandas as pd
import torch
from torch.jit import script, trace
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import csv
import random
from pathlib import Path
import re
import codecs
from io import open
import itertools
import math
import unicodedata

In [3]:
USE_CUDA = torch.cuda.is_available()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [10]:
Path.ls = lambda x: list(x.iterdir())

In [11]:
!ls ../input

README.txt		       movie_conversations.txt	  raw_script_urls.txt
chameleons.pdf		       movie_lines.txt
movie_characters_metadata.txt  movie_titles_metadata.txt


In [9]:
PATH = Path('../input')
def printLines(file, n=10):
    with open(file, 'rb') as f:
        lines = f.readlines()
    for line in lines[:n]: print(line)
printLines(PATH/'movie_lines.txt')

b'L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!\n'
b'L1044 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ They do to!\n'
b'L985 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I hope so.\n'
b'L984 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ She okay?\n'
b"L925 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Let's go.\n"
b'L924 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ Wow\n'
b"L872 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Okay -- you're gonna need to learn how to lie.\n"
b'L871 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ No\n'
b'L870 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I\'m kidding.  You know how sometimes you just become this "persona"?  And you don\'t know how to quit?\n'
b'L869 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Like my fear of wearing pastels?\n'


# Some preprocessing

### loadLines() function splits each line of file into dictionary of (LineID, CharacterID, MovieID, Character, Text)

In [60]:
def loadLines(fileName, fields):
    lines = {}
    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(' +++$+++ ')
            lineObj = {}
            for i, field in enumerate(fields): lineObj[field] = values[i]
            lines[lineObj['lineID']] = lineObj
    return lines

### loadConversations() groups the lines according to the conversations that are happening in conversations.txt

In [61]:
def loadConversations(fileName, lines, fields):
    conversations = []
    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(' +++$+++ ')
            convObj = {}
            for i, field in enumerate(fields): convObj[field] = values[i]
            lineIDs = eval(convObj['utteranceIDs'])
            convObj['lines'] = []
            for lineID in lineIDs: convObj['lines'].append(lines[lineID])
            conversations.append(convObj)
    return conversations

### extractSentencePairs() extracts the sentence pairs used in the conversation from the previously constructed conversations list

In [62]:
def extractSentencePair(conversations):
    qa_pairs = []
    for conversation in conversations:
        for i in range(len(conversation['lines'])-1):
            inputLine = conversation['lines'][i]['text'].strip()
            targetLine = conversation['lines'][i+1]['text'].strip()
            if inputLine and targetLine: qa_pairs.append([inputLine, targetLine])
    return qa_pairs

### Calling these three functions and save them as formatted_lines.txt

In [63]:
file = 'formatted_lines.txt'
delimiter = '\t'
delimiter = str(codecs.decode(delimiter, 'unicode_escape'))
LINE_FIELDS = ['lineID', 'characterID', 'movieID', 'character', 'text']
CONV_FIELDS = ['character1ID', 'character2ID', 'movieID', 'utteranceIDs']

print('Preprocessing...\n')
lines = loadLines(PATH/'movie_lines.txt', LINE_FIELDS)

print('Loading Conversations...\n')
convs = loadConversations(PATH/'movie_conversations.txt', lines, CONV_FIELDS)

print('Done!!!')

Preprocessing...

Loading Conversations...

Done!!!


### Creating a new file to store the sentence pairs

In [65]:
print('Writing new file to get dialogues...\n')
with open(file, 'w', encoding='utf-8') as out:
    writer = csv.writer(out, delimiter=delimiter, lineterminator='\n')
    for pair in extractSentencePair(convs): writer.writerow(pair)

print('Sample lines from file\n')
printLines(file)

Writing new file to get dialogues...

Sample lines from file

b"Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\tWell, I thought we'd start with pronunciation, if that's okay with you.\n"
b"Well, I thought we'd start with pronunciation, if that's okay with you.\tNot the hacking and gagging and spitting part.  Please.\n"
b"Not the hacking and gagging and spitting part.  Please.\tOkay... then how 'bout we try out some French cuisine.  Saturday?  Night?\n"
b"You're asking me out.  That's so cute. What's your name again?\tForget it.\n"
b"No, no, it's my fault -- we didn't have a proper introduction ---\tCameron.\n"
b"Cameron.\tThe thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\n"
b"The thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\tSeems like she could get 

# Loading and trimming data

Now we have to create a vocabulary for our model. Because our model can only deal with numbers and not text, we will convert text to numbers and pass that to model to get results. We will create mapping for each word to a number

In [72]:
PAD = 0
BOS = 1
EOS = 2
class Vocab:
    def __init__(self):
        self.trimmed = False
        self.word2idx = {}
        self.wordcount = {}
        self.idx2word = {PAD: 'PAD', BOS: 'BOS', EOS: 'EOS'}
        self.num_words = 3
    
    def addSentence(self, sentence):
        for word in sentence.split(' '): self.addWord(word)
    
    def addWord(self, word):
        if word not in self.word2idx:
            self.word2idx[word] = self.num_words
            self.wordcount[word] = 1
            self.idx2word[self.num_words] = word
            self.num_words+=1
        else: self.wordcount[word]+=1
    
    #trim words with count lower than a particular threshold
    def trim(self, min_freq):
        if self.trimmed: return
        keep_words = []
        for k, v in self.wordcount.items():
            if v>=min_freq: keep_words.append(k)
        print(f'Keep {len(keep_words)} words out of {len(self.wordcount)} words')
        self.word2idx = {}
        self.wordcount = {}
        self.idx2word = {PAD: 'PAD', BOS: 'BOS', EOS: 'EOS'}
        self.num_words = 3
        for word in keep_words: self.addWord(word)

Convert the sentence to ASCII first, then to lowercase, trim all the punctuations. We consider sentences which have length under the MAX_LENGTH threshold for ease of calculaton. 

In [73]:
MAX_LENGTH = 10

def unicodeToASCII(sentence):
    return ''.join(
        c for c in unicodedata.normalize('NFD', sentence) 
        if unicodedata.category(c)!='Mn'
    )

def normalizeString(s):
    s = unicodeToASCII(s)
    s = re.sub(r'([.!?])', r' \1', s)
    s = re.sub(r'[^a-zA-Z.!?]+', r' ', s)
    s = re.sub(r'\s+', r' ', s).strip()
    return s

def readVocs(fileName):
    print('Reading lines')
    lines = open(fileName, encoding='utf-8').read().strip().split('\n')
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]
    voc = Vocab()
    return voc, pairs

def filterPair(p):
    return len(p[0].split(' '))<MAX_LENGTH and len(p[1].split(' '))<MAX_LENGTH

def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]

def loadPrepareData(fileName):
    print('Preparing Training Data....')
    voc, pairs = readVocs(fileName)
    print('Reading Sentence Pairs...')
    pairs = filterPairs(pairs)
    print(f'Selected all the pairs with sentence length {MAX_LENGTH}')
    print('Counting words')
    for pair in pairs:
        voc.addSentence(pair[0])
        voc.addSentence(pair[1])
    print('Number of words in vocabulary:', voc.num_words)
    return voc, pairs

voc, pairs = loadPrepareData(file)
print('\nPairs')
for pair in pairs[:10]: print(pair)

Preparing Training Data....
Reading lines
Reading Sentence Pairs...
Selected all the pairs with sentence length 10
Counting words
Number of words in vocabulary: 22448

Pairs
['There .', 'Where ?']
['You have my word . As a gentleman', 'You re sweet .']
['Hi .', 'Looks like things worked out tonight huh ?']
['You know Chastity ?', 'I believe we share an art instructor']
['Have fun tonight ?', 'Tons']
['Well no . . .', 'Then that s all you had to say .']
['Then that s all you had to say .', 'But']
['But', 'You always been this selfish ?']
['do you listen to this crap ?', 'What crap ?']
['What good stuff ?', 'The real you .']


Trim rarely used words to achieve faster convergence. This will lead to removing some pairs too.

In [74]:
MIN_COUNT = 3
def trimRare(voc, pairs, MIN_COUNT):
    voc.trim(MIN_COUNT)
    keep_pairs = []
    for pair in pairs:
        input_sentence = pair[0]
        output_sentence = pair[1]
        keep_inp, keep_out = True, True
        for word in input_sentence.split(' '):
            if word not in voc.word2idx:
                keep_inp = False
                break
        for word in output_sentence.split(' '):
            if word not in voc.word2idx:
                keep_out = False
                break
        if keep_inp and keep_out: keep_pairs.append(pair)
    print(f'Trimmed number of pairs from {len(pairs)} to {len(keep_pairs)}')
    return keep_pairs

pairs = trimRare(voc, pairs, MIN_COUNT)

Keep 9090 words out of 22445 words
Trimmed number of pairs from 64271 to 50411


In [75]:
pairs[:10]

[['There .', 'Where ?'],
 ['You have my word . As a gentleman', 'You re sweet .'],
 ['Hi .', 'Looks like things worked out tonight huh ?'],
 ['Have fun tonight ?', 'Tons'],
 ['Well no . . .', 'Then that s all you had to say .'],
 ['Then that s all you had to say .', 'But'],
 ['But', 'You always been this selfish ?'],
 ['do you listen to this crap ?', 'What crap ?'],
 ['What good stuff ?', 'The real you .'],
 ['Wow', 'Let s go .']]

# Convert the data into batches
The data we have does not take advantage of GPU capabilites. We have to convert the data to tensor. So we convert the data into tensors of size (max_length x batch_size) for each time step, it will take a batch of input words as input and pass it into network. We do this using itertools.zip_longest() function (See its docs)

In [76]:
def indexFromSentence(voc, sentence): 
    return [voc.word2idx[word] for word in sentence.split(' ')] + [EOS]

def zeroPadding(l, fillvalue=PAD):
    return list(itertools.zip_longest(*l, fillvalue=fillvalue))

#shows which indices of the target matrix are padded i.e are not there but just added to pad shorter sentences
def binaryMatrix(l, value=PAD):
    m = []
    for i, seq in enumerate(l):
        m.append([])
        for token in seq: 
            if token==PAD: m[i].append(0)
            else: m[i].append(1)
    return m

def inputVar(l, voc):
    indexes_batch = [indexFromSentence(voc, sentence) for sentence in l]
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    padVar = torch.LongTensor(padList)
    return padVar, lengths

def outputVar(l, voc):
    indexes_batch = [indexFromSentence(voc, sentence) for sentence in l]
    max_target_len = max([len(idx) for idx in indexes_batch])
    padList = zeroPadding(indexes_batch)
    mask = binaryMatrix(padList)
    mask = torch.ByteTensor(mask)
    padVar = torch.LongTensor(padList)
    return padVar, mask, max_target_len

def batch2TrainData(voc, pair_batch):
    pair_batch.sort(key=lambda x: len(x[0].split(' ')), reverse=True)
    input_batch, output_batch = [], []
    for pair in pair_batch:
        input_batch.append(pair[0])
        output_batch.append(pair[1])
    inp, lengths = inputVar(input_batch, voc)
    output, mask, max_target_len = outputVar(output_batch, voc)
    return inp, lengths, output, mask, max_target_len

small_batch_size = 5
batches = batch2TrainData(voc, [random.choice(pairs) for _ in range(small_batch_size)])
inp, length, out, mask, max_target_len = batches
print('Input Variable:', inp)
print('lengths:', length)
print('Target Variable:', out)
print('Mask:', mask)
print('Maximum Target Length:', max_target_len)

Input Variable: tensor([[ 139,  128,   71,  134, 7710],
        [ 280,    4,   40,  466,    2],
        [ 125,  130,  500,   87,    0],
        [ 226,  459,  137,  430,    0],
        [ 666,  661, 5766,    4,    0],
        [  42,  116, 1821,    2,    0],
        [  80,   80,    6,    0,    0],
        [  23, 1874,    2,    0,    0],
        [   6,    6,    0,    0,    0],
        [   2,    2,    0,    0,    0]])
lengths: tensor([10, 10,  8,  6,  2])
Target Variable: tensor([[ 133, 1138,  358,  160, 7710],
        [1036,  114,    4,   49,    2],
        [ 648,    4,    2,   40,    0],
        [   4,    2,    0,   24,    0],
        [   2,    0,    0,    6,    0],
        [   0,    0,    0,    2,    0]])
Mask: tensor([[1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 0],
        [1, 1, 0, 1, 0],
        [1, 0, 0, 1, 0],
        [0, 0, 0, 1, 0]], dtype=torch.uint8)
Maximum Target Length: 6


# Define the models
We will use the seq2seq model with encoder and decoder. It will take a variable length sequence as input and return a variable length sequence as an output.

We will use two RNNs(specifically GRU), one as encoder and other as decoder. 

![seq2seq model block diagram](https://jeddy92.github.io/images/ts_intro/seq2seq_ts.png)

# Encoder
We will use a bidirectional RNN as the encoder. It iterates through input sentence one token at a time and at each step outputs the output vector and a hidden state vector. Hidden state is passed to next time step and output vector is recorded. Thus encoder transforms the inputs it saw to some context which decoder uses to create output. The variant of GRU we are using can be considered as two independent GRUs: 

1. In which input is fed in normal sequence order
2. In which input is fed in reverse order. 

Thus, the output is sum of both past and future contexts

![Bidirectional RNN](https://pytorch.org/tutorials/_images/RNN-bidirectional.png)

We know that the RNN will accept a word of size hidden size. For this purpose, we use an embedding layer which maps each word to a vector of size hidden size. This embedding layer will be trained or we can also use glove or fastText embeddings.

Also, while passing the padded batch, we must pack and unpack the padding around RNN. So, we use `nn.utils.rnn.pack_padded_sequence` and `nn.utils.rnn.pad_packed_sequence`

Here are steps to use:

1. Convert word indices to embeddings
2. Pack the batch to RNN
3. Forward pass
4. Unpack padding
5. Sum bidirectional outputs
6. Return outputs and final states

In [155]:
class EncoderRNN(nn.Module):
    def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
        super(EncoderRNN, self).__init__()
        self.n_layers = n_layers 
        self.embedding = embedding
        self.hidden_size = hidden_size
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, 
                          dropout=(0 if n_layers==1 else dropout), bidirectional=True)
    
    def forward(self, input_seq, input_lengths, hidden=None):
        embedded = self.embedding(input_seq)
        packed = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
        outputs, hidden = self.gru(packed, hidden)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs)
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:]
        return outputs, hidden

# Decoder

Decoder generates the output using the context vector created from the encoder. It continues generating text until it encounters an EOS token. One problem, though, with the decoder is that it can have information loss. To tackle this, we use attention.

In attention, decoder pays attention to some parts of input sequence rather seeing the whole input sequence. This is somewhat human like. We pay attention to not all words of a sentence but rather some words, which are important.

Attention is calculated using decoder's current hidden state and encoder's outputs. They have same shape as input sequence, so they can be multiplied with encoder outputs and we get a weighted sum which indicates which output to pay attention to. 

![Attention in action](https://pytorch.org/tutorials/_images/attn2.png)

We will use 'Global Attention'. What this type of attention does is that we consider all the encoder hidden states. In this type of attention, we need decoder's current output only. We use score functions to calculate attention energies between encoder outputs and decoder outputs.

<img src='https://pytorch.org/tutorials/_images/scores.png' alt='Score functions' width=300 height=300>
where $h_t$ = current decoder state, $W_a$ = weights of attention layer which have to be learnt and $\bar h_s$ = all encoder states


Global attention is shown in the figure. Output of this layer is softmax normalized tensor of shape (batch_size, 1, max_length)
<img src='https://pytorch.org/tutorials/_images/global_attn.png' alt='Global Attention' width='300' height='300'/>

In [156]:
class Attn(nn.Module):
    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()
        self.method = method
        if self.method not in ['general', 'dot', 'concat']:
            raise ValueError(self.method, 'is not an appropriate attention model')
        self.hidden_size = hidden_size
        if self.method=='general': self.attn = nn.Linear(self.hidden_size, self.hidden_size)
        elif self.method=='concat':
            self.attn = nn.Linear(self.hidden_size*2, self.hidden_size)
            self.v = nn.Parameter(torch.FloatTensor(hidden_size))
        
    def dot_score(self, hidden, encoder_output):
        return torch.sum(hidden * encoder_output, dim=2)

    def general_score(self, hidden, encoder_output):
        energy = self.attn(encoder_output)
        return torch.sum(hidden * energy, dim=2)

    def concat_score(self, hidden, encoder_output):
        energy = self.attn(torch.cat((hidden.expand(encoder_output.size(0), -1, -1), encoder_output), 2)).tanh()
        return torch.sum(self.v * energy, dim=2)
        
    def forward(self, hidden, encoder_output):
        if self.method=='general': attn_energies = self.general_score(hidden, encoder_output)
        elif self.method=='dot': attn_energies = self.dot_score(hidden, encoder_output)
        elif self.method=='concat': attn_energies = self.concat_score(hidden, encoder_output)
        attn_energies = attn_energies.t()
        return F.softmax(attn_energies, dim=1).unsqueeze(1)

**Steps to create the decoder**
1. Get embeddings of the current input
2. Forward pass through GRU
3. Calculate attention weights from the current GRU output
4. Multiply attention weights to encoder outputs to get weighted sum context vector
5. Concatenate weighted context vector and GRU output
6. Predict next word
7. Return output and final hidden state

In [157]:
class DecoderRNN(nn.Module):
    def __init__(self, attn_model, embedding, hidden_size, output_size, n_layers=1, dropout=0.1):
        super(DecoderRNN, self).__init__()
        self.attn_model = attn_model
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = dropout
        
        self.embedding = embedding
        self.embedding_dropout = nn.Dropout(dropout)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=(0 if n_layers==1 else self.dropout))
        self.concat = nn.Linear(hidden_size*2, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.attn = Attn(attn_model, hidden_size)
    
    def forward(self, input_step, last_hidden, encoder_output):
        embedded = self.embedding(input_step)
        embedded = self.embedding_dropout(embedded)
        rnn_output, hidden = self.gru(embedded, last_hidden)
        attn_weights = self.attn(rnn_output, encoder_output)
        context = attn_weights.bmm(encoder_output.transpose(0, 1))
        rnn_output = rnn_output.squeeze(0)
        context = context.squeeze(1)
        concat_input = torch.cat((rnn_output, context), 1)
        concat_output = torch.tanh(self.concat(concat_input))
        output = self.out(concat_output)
        output = F.softmax(output, dim=1)
        return output, hidden

# Training the model

# Masked Loss
Since we are dealing with padded sequences, we cannot consider all elements of tensor when calculating loss. We define maskedNLLLoss to calculate loss based on decoder's output tensor, target tensor and mask tensor describing the padding. It calculates NLLLoss of elements which have 1 in mask

In [158]:
def maskedNLLLoss(inp, target, mask):
    total = mask.sum()
    crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
    loss = crossEntropy.masked_select(mask).mean()
    loss = loss.to(device)
    return loss, total.item()

# Training
We use some tricks while training:
1. **Teacher forcing** - With some probability, we use current target word as decoder's next input rather than decoder's guessed word. It leads to more efficiency. But we have to set the probability carefully otherwise it will lead to instability during inference.

2. **Gradient Clipping** - To counter exploding gradient problem by clipping the gradients to some maximum value.

![Gradient Clipping](https://pytorch.org/tutorials/_images/grad_clip.png)

**Steps**
1. Forward pass through encoder
2. Initialize decoder inputs as BOS and hidden state as encoder's final hidden state
3. Forward pass input batch sequence through decoder one step at time
4. If teacher forcing: set next decoder input as current target else as current decoder output
5. Calculate and accumulate loss
6. Perform backprop
7. Clip gradients
8. Update encoder and decoder params

In [159]:
def train(input_var, lengths, target_var, mask,
          max_target_len, encoder, decoder, embedding, 
          encoder_optimizer, decoder_optimizer, batch_size,
          clip, max_length=MAX_LENGTH):
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
    input_var = input_var.to(device)
    lengths = lengths.to(device)
    target_var = target_var.to(device)
    mask = mask.to(device)
    loss = 0
    print_losses = []
    n_totals = 0
    encoder_out, encoder_hidden = encoder(input_var, lengths)
    decoder_in = torch.LongTensor([[BOS for _ in range(batch_size)]])
    decoder_in = decoder_in.to(device)
    decoder_hidden = encoder_hidden[:decoder.n_layers]
    use_teacher_forcing = True if random.random()<teacher_forcing_ratio else False
    if use_teacher_forcing:
        for t in range(max_target_len):
            decoder_out, decoder_hidden = decoder(decoder_in, decoder_hidden, encoder_out)
            decoder_in = target_var[t].view(1, -1)
            mask_loss, nTotal = maskedNLLLoss(decoder_out, target_var[t], mask[t])
            loss+=mask_loss
            print_losses.append(mask_loss.item()*nTotal)
            n_totals+=nTotal
    else: 
        for t in range(max_target_len):
            decoder_out, decoder_hidden = decoder(decoder_in, decoder_hidden, encoder_out)
            _, topi = decoder_out.topk(1)
            decoder_in = torch.LongTensor([[topi[1][0] for i in range(batch_size)]])
            decoder_in = decoder_in.to(device)
            mask_loss, nTotal = maskedNLLLoss(decoder_out, target_var[t], mask[t])
            loss+=mask_loss
            print_losses.append(mask_loss.item()*nTotal)
            n_totals+=nTotal
    loss.backward()
    _ = nn.utils.clip_grad_norm_(encoder.parameters(), clip)
    _ = nn.utils.clip_grad_norm_(decoder.parameters(), clip)
    encoder_optimizer.step()
    decoder_optimizer.step()
    return sum(print_losses)/n_totals

# Training Iterations
Run iterations of training with models, optimizers, data, etc.

In [160]:
def trainIters(voc, pairs, encoder, decoder, encoder_opt, 
               decoder_opt, embedding, encoder_n_layers, decoder_n_layers,
               n_iterations, batch_size, print_every, clip):
    training_batches = [batch2TrainData(voc, [random.choice(pairs) for _ in range(batch_size)]) 
                        for _ in range(n_iterations)]
    print('Initializing...')
    start_iteration = 1
    print_loss = 0        
    print('Training...')
    for iteration in range(start_iteration, n_iterations+1):
        training_batch = training_batches[iteration-1]
        inp, lens, targ, mask, max_targ_len = training_batch
        loss = train(inp, lens, targ, mask, max_targ_len, 
                     encoder, decoder, embedding, encoder_opt, 
                     decoder_opt, batch_size, clip)
        print_loss+=loss
        if iteration%print_every==0:
            print_loss_avg = print_loss / print_every
            print(f'Iteration {iteration}, Percentage: {iteration/n_iterations * 100}, Average Loss: {print_loss_avg}')
            print_loss = 0

# Evaluation

We have to evaluate the bot too

**Greedy Decoding**
During training, when we don't use teacher forcing, this method is applied. Simply choose word with highest softmax value. It is optimal for single time step

**Steps**
1. Forward pass through encoder
2. Encoder's final hidden state is first hidden state of decoder
3. Decoder's first word is BOS token
4. Initialize tensors to append decoded words
5. Iteratively decode one at a time
    5.1. Forward pass through decoder
    5.2. Obtain most likely word token and its score
    5.3. Record token and score
    5.4. Prepare current token to be next input
6. Return word tokens and scores

In [167]:
class GreedySearchDecoder(nn.Module):
    def __init__(self, encoder, decoder):
        super(GreedySearchDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        
    def forward(self, input_seq, input_len, max_len):
        encoder_out, encoder_hidden = self.encoder(input_seq, input_len)
        decoder_hidden = encoder_hidden[:decoder.n_layers]
        decoder_input = torch.ones(1, 1, device=device, dtype=torch.long) * BOS
        all_tokens = torch.zeros([0], device=device, dtype=torch.long)
        all_scores = torch.zeros([0], device=device)
        for _ in range(max_len):
            decoder_out, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_out)
            decoder_scores, decoder_input = torch.max(decoder_out, dim=1)
            all_tokens = torch.cat((all_tokens, decoder_input), dim=0)
            all_scores = torch.cat((all_scores, decoder_scores), dim=0)
            decoder_input = torch.unsqueeze(decoder_input, 0)
        return all_tokens, all_scores

In [168]:
def evaluate(encoder, decoder, searcher, voc, sentence, max_length=MAX_LENGTH):
    index_batch = [indexFromSentence(voc, sentence)]
    lengths = torch.tensor([len(indxs) for indxs in index_batch])
    input_batch = torch.LongTensor(index_batch).transpose(0, 1)
    input_batch = input_batch.to(device)
    lengths = lengths.to(device)
    tokens, scores = searcher(input_batch, lengths, max_length)
    decoded_words = [voc.idx2word[token.item()] for token in tokens]
    return decoded_words

def evaluateInput(encoder, decoder, searcher, voc):
    input_sentence = ''
    while(1):
        try:
            input_sentence = input('> ')
            if input_sentence=='q' or input_sentence=='quit': break
            input_sentence = normalizeString(input_sentence)
            output_words = evaluate(encoder, decoder, searcher, voc, input_sentence)
            output_words[:] = [x for x in output_words if not (x=='EOS' or x=='PAD')]
            print('Bot:', ' '.join(output_words))
        except KeyError: print('Unknown word for the bot')

In [163]:
attn_model = 'dot'
hidden_size = 500
encoder_n_layers = 2
decoder_n_layers = 2
dropout = 0.1
batch_size = 64

print('Building encoder and decoder....')
embedding = nn.Embedding(voc.num_words, hidden_size)
encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
decoder = DecoderRNN(attn_model, embedding, hidden_size, voc.num_words, decoder_n_layers, dropout)
encoder = encoder.to(device)
decoder = decoder.to(device)
print('Models built... Ready to go!')

Building encoder and decoder....
Models built... Ready to go!


In [164]:
clip = 50.
teacher_forcing_ratio = 1.
lr = 0.0001
decoder_learning_ratio = 5.
n_iterations = 4000
print_every = 1

encoder.train()
decoder.train()

print('Building optimizers....')
encoder_opt = optim.Adam(encoder.parameters(), lr=lr)
decoder_opt = optim.Adam(decoder.parameters(), lr=lr*decoder_learning_ratio)

print('Start Training...')
trainIters(voc, pairs, encoder, decoder, encoder_opt, decoder_opt, embedding, 
           encoder_n_layers, decoder_n_layers, n_iterations, batch_size, print_every, clip)

Building optimizers....
Start Training...
Initializing...
Training...
Iteration 1, Percentage: 0.025, Average Loss: 9.113757581853156
Iteration 2, Percentage: 0.05, Average Loss: 9.002455254946307
Iteration 3, Percentage: 0.075, Average Loss: 8.83112385626615
Iteration 4, Percentage: 0.1, Average Loss: 8.539997522077412
Iteration 5, Percentage: 0.125, Average Loss: 8.105494439160383
Iteration 6, Percentage: 0.15, Average Loss: 7.6242798905384
Iteration 7, Percentage: 0.17500000000000002, Average Loss: 7.115524013837178
Iteration 8, Percentage: 0.2, Average Loss: 7.071056206081365
Iteration 9, Percentage: 0.22499999999999998, Average Loss: 6.980994437026511
Iteration 10, Percentage: 0.25, Average Loss: 6.527976467177786
Iteration 11, Percentage: 0.27499999999999997, Average Loss: 6.4013454230616285
Iteration 12, Percentage: 0.3, Average Loss: 6.137634895235225
Iteration 13, Percentage: 0.325, Average Loss: 5.825482413591936
Iteration 14, Percentage: 0.35000000000000003, Average Loss: 5.

Iteration 124, Percentage: 3.1, Average Loss: 4.591095066181599
Iteration 125, Percentage: 3.125, Average Loss: 4.374129741177928
Iteration 126, Percentage: 3.15, Average Loss: 4.372848302816837
Iteration 127, Percentage: 3.175, Average Loss: 4.707269759159389
Iteration 128, Percentage: 3.2, Average Loss: 4.522561244557544
Iteration 129, Percentage: 3.225, Average Loss: 4.53843616567007
Iteration 130, Percentage: 3.25, Average Loss: 4.370283496420027
Iteration 131, Percentage: 3.2750000000000004, Average Loss: 4.40782362581654
Iteration 132, Percentage: 3.3000000000000003, Average Loss: 4.49354523355692
Iteration 133, Percentage: 3.325, Average Loss: 4.1777188432669
Iteration 134, Percentage: 3.35, Average Loss: 4.287301296938905
Iteration 135, Percentage: 3.375, Average Loss: 4.300609445411693
Iteration 136, Percentage: 3.4000000000000004, Average Loss: 4.317210147043272
Iteration 137, Percentage: 3.4250000000000003, Average Loss: 4.463040717278225
Iteration 138, Percentage: 3.45, Ave

Iteration 248, Percentage: 6.2, Average Loss: 4.246771013783752
Iteration 249, Percentage: 6.225, Average Loss: 3.799700184887092
Iteration 250, Percentage: 6.25, Average Loss: 4.177906242218644
Iteration 251, Percentage: 6.275, Average Loss: 4.082356225650403
Iteration 252, Percentage: 6.3, Average Loss: 4.027624241151791
Iteration 253, Percentage: 6.325, Average Loss: 4.0845619822183
Iteration 254, Percentage: 6.35, Average Loss: 4.03814251364983
Iteration 255, Percentage: 6.375, Average Loss: 3.8061416326335777
Iteration 256, Percentage: 6.4, Average Loss: 3.8638889686284976
Iteration 257, Percentage: 6.425, Average Loss: 3.982108518367505
Iteration 258, Percentage: 6.45, Average Loss: 4.082198551350936
Iteration 259, Percentage: 6.4750000000000005, Average Loss: 3.903280618084218
Iteration 260, Percentage: 6.5, Average Loss: 3.8099852491640394
Iteration 261, Percentage: 6.525, Average Loss: 4.170815387079637
Iteration 262, Percentage: 6.550000000000001, Average Loss: 3.995262487199

Iteration 368, Percentage: 9.2, Average Loss: 4.008531139633299
Iteration 369, Percentage: 9.225, Average Loss: 3.8449268379739223
Iteration 370, Percentage: 9.25, Average Loss: 3.915085432134144
Iteration 371, Percentage: 9.275, Average Loss: 3.9074711708724497
Iteration 372, Percentage: 9.3, Average Loss: 3.6612245395760454
Iteration 373, Percentage: 9.325, Average Loss: 4.210334830706818
Iteration 374, Percentage: 9.35, Average Loss: 3.4686821729140846
Iteration 375, Percentage: 9.375, Average Loss: 4.154268293827772
Iteration 376, Percentage: 9.4, Average Loss: 3.94263782720184
Iteration 377, Percentage: 9.425, Average Loss: 3.940402298273208
Iteration 378, Percentage: 9.45, Average Loss: 3.687457755019393
Iteration 379, Percentage: 9.475, Average Loss: 3.702882561444023
Iteration 380, Percentage: 9.5, Average Loss: 3.721686791222826
Iteration 381, Percentage: 9.525, Average Loss: 4.020056515430966
Iteration 382, Percentage: 9.55, Average Loss: 3.6773814237353353
Iteration 383, Per

Iteration 490, Percentage: 12.25, Average Loss: 3.736992290679347
Iteration 491, Percentage: 12.275, Average Loss: 3.90685877042464
Iteration 492, Percentage: 12.3, Average Loss: 4.009992814173749
Iteration 493, Percentage: 12.325, Average Loss: 3.8902524817167077
Iteration 494, Percentage: 12.35, Average Loss: 3.861231638261174
Iteration 495, Percentage: 12.375, Average Loss: 3.6677014340397336
Iteration 496, Percentage: 12.4, Average Loss: 3.869147916112095
Iteration 497, Percentage: 12.425, Average Loss: 3.625511478652399
Iteration 498, Percentage: 12.45, Average Loss: 3.603498091166954
Iteration 499, Percentage: 12.475, Average Loss: 3.692457383994292
Iteration 500, Percentage: 12.5, Average Loss: 3.9084532900020306
Iteration 501, Percentage: 12.525, Average Loss: 3.6828683243662703
Iteration 502, Percentage: 12.55, Average Loss: 3.8607264810614286
Iteration 503, Percentage: 12.575, Average Loss: 3.805897161749327
Iteration 504, Percentage: 12.6, Average Loss: 3.358877807467555
Ite

Iteration 605, Percentage: 15.125, Average Loss: 3.756878431858703
Iteration 606, Percentage: 15.15, Average Loss: 3.725889656231786
Iteration 607, Percentage: 15.174999999999999, Average Loss: 3.831208350459257
Iteration 608, Percentage: 15.2, Average Loss: 3.843493576048676
Iteration 609, Percentage: 15.225, Average Loss: 3.8068132221887385
Iteration 610, Percentage: 15.25, Average Loss: 3.6020869911857405
Iteration 611, Percentage: 15.275, Average Loss: 3.7610111417821965
Iteration 612, Percentage: 15.299999999999999, Average Loss: 3.663785270080868
Iteration 613, Percentage: 15.325, Average Loss: 3.6279416226478944
Iteration 614, Percentage: 15.35, Average Loss: 3.7885219485974773
Iteration 615, Percentage: 15.375, Average Loss: 3.773885869403738
Iteration 616, Percentage: 15.4, Average Loss: 3.835695487676473
Iteration 617, Percentage: 15.425, Average Loss: 3.803587251171774
Iteration 618, Percentage: 15.45, Average Loss: 3.836345426556153
Iteration 619, Percentage: 15.475, Averag

Iteration 724, Percentage: 18.099999999999998, Average Loss: 3.5052583088766975
Iteration 725, Percentage: 18.125, Average Loss: 3.6842247891307176
Iteration 726, Percentage: 18.15, Average Loss: 3.755692847421751
Iteration 727, Percentage: 18.175, Average Loss: 3.36847269166337
Iteration 728, Percentage: 18.2, Average Loss: 3.476660368587612
Iteration 729, Percentage: 18.224999999999998, Average Loss: 3.5968792913341896
Iteration 730, Percentage: 18.25, Average Loss: 3.8534617199324037
Iteration 731, Percentage: 18.275, Average Loss: 3.498249895822179
Iteration 732, Percentage: 18.3, Average Loss: 3.9701358005404472
Iteration 733, Percentage: 18.325, Average Loss: 3.5690345415579423
Iteration 734, Percentage: 18.35, Average Loss: 3.6226829591290497
Iteration 735, Percentage: 18.375, Average Loss: 3.5131191225769496
Iteration 736, Percentage: 18.4, Average Loss: 3.5981504230464205
Iteration 737, Percentage: 18.425, Average Loss: 3.4359910370065614
Iteration 738, Percentage: 18.45, Aver

Iteration 843, Percentage: 21.075, Average Loss: 3.8740565852803517
Iteration 844, Percentage: 21.099999999999998, Average Loss: 3.37057054927823
Iteration 845, Percentage: 21.125, Average Loss: 3.4762126170334557
Iteration 846, Percentage: 21.15, Average Loss: 3.301451800821295
Iteration 847, Percentage: 21.175, Average Loss: 3.498151582574997
Iteration 848, Percentage: 21.2, Average Loss: 3.6888982568736575
Iteration 849, Percentage: 21.224999999999998, Average Loss: 3.757354627465852
Iteration 850, Percentage: 21.25, Average Loss: 3.5737603904786877
Iteration 851, Percentage: 21.275, Average Loss: 3.516980726590658
Iteration 852, Percentage: 21.3, Average Loss: 3.5200210044293883
Iteration 853, Percentage: 21.325, Average Loss: 3.5363724960775023
Iteration 854, Percentage: 21.349999999999998, Average Loss: 3.5931327137893634
Iteration 855, Percentage: 21.375, Average Loss: 3.6326102009201846
Iteration 856, Percentage: 21.4, Average Loss: 3.615693620872914
Iteration 857, Percentage: 

Iteration 966, Percentage: 24.15, Average Loss: 3.5713237063315217
Iteration 967, Percentage: 24.175, Average Loss: 3.6795611492758664
Iteration 968, Percentage: 24.2, Average Loss: 3.6028051377750585
Iteration 969, Percentage: 24.224999999999998, Average Loss: 3.262374256496591
Iteration 970, Percentage: 24.25, Average Loss: 3.630646933076262
Iteration 971, Percentage: 24.275, Average Loss: 3.3874839414439695
Iteration 972, Percentage: 24.3, Average Loss: 3.7438375784655165
Iteration 973, Percentage: 24.325, Average Loss: 3.388510977665995
Iteration 974, Percentage: 24.349999999999998, Average Loss: 3.518131069358318
Iteration 975, Percentage: 24.375, Average Loss: 3.2929275731238987
Iteration 976, Percentage: 24.4, Average Loss: 3.508760785670294
Iteration 977, Percentage: 24.425, Average Loss: 3.528369865132263
Iteration 978, Percentage: 24.45, Average Loss: 3.506516846223768
Iteration 979, Percentage: 24.474999999999998, Average Loss: 3.6301967104793422
Iteration 980, Percentage: 2

Iteration 1086, Percentage: 27.150000000000002, Average Loss: 3.465429635718465
Iteration 1087, Percentage: 27.175, Average Loss: 3.395909821852786
Iteration 1088, Percentage: 27.200000000000003, Average Loss: 3.1821835131555294
Iteration 1089, Percentage: 27.224999999999998, Average Loss: 3.643518547373018
Iteration 1090, Percentage: 27.250000000000004, Average Loss: 3.404370090314789
Iteration 1091, Percentage: 27.275, Average Loss: 3.35383866589383
Iteration 1092, Percentage: 27.3, Average Loss: 3.5040967982178763
Iteration 1093, Percentage: 27.325, Average Loss: 3.216582526156058
Iteration 1094, Percentage: 27.35, Average Loss: 3.56999656326619
Iteration 1095, Percentage: 27.375, Average Loss: 3.369053530917075
Iteration 1096, Percentage: 27.400000000000002, Average Loss: 3.601747438471999
Iteration 1097, Percentage: 27.425, Average Loss: 3.6471242315255767
Iteration 1098, Percentage: 27.450000000000003, Average Loss: 3.7044164167015645
Iteration 1099, Percentage: 27.47499999999999

Iteration 1200, Percentage: 30.0, Average Loss: 3.4911655249742464
Iteration 1201, Percentage: 30.025000000000002, Average Loss: 3.693902840867081
Iteration 1202, Percentage: 30.049999999999997, Average Loss: 3.501536156460643
Iteration 1203, Percentage: 30.075000000000003, Average Loss: 3.5262261390806504
Iteration 1204, Percentage: 30.099999999999998, Average Loss: 3.3506670320889462
Iteration 1205, Percentage: 30.125, Average Loss: 3.206082761893049
Iteration 1206, Percentage: 30.15, Average Loss: 3.537330448763763
Iteration 1207, Percentage: 30.175, Average Loss: 3.264553429391147
Iteration 1208, Percentage: 30.2, Average Loss: 3.59834542175134
Iteration 1209, Percentage: 30.225, Average Loss: 3.407441723299664
Iteration 1210, Percentage: 30.25, Average Loss: 3.5468924208583994
Iteration 1211, Percentage: 30.275000000000002, Average Loss: 3.209777322821319
Iteration 1212, Percentage: 30.3, Average Loss: 3.2190513108864387
Iteration 1213, Percentage: 30.325000000000003, Average Loss

Iteration 1314, Percentage: 32.85, Average Loss: 3.2643768033407574
Iteration 1315, Percentage: 32.875, Average Loss: 3.4227821547439228
Iteration 1316, Percentage: 32.9, Average Loss: 3.315968498455591
Iteration 1317, Percentage: 32.925, Average Loss: 3.342042712181637
Iteration 1318, Percentage: 32.95, Average Loss: 3.538895480353717
Iteration 1319, Percentage: 32.975, Average Loss: 3.3753353692677637
Iteration 1320, Percentage: 33.0, Average Loss: 3.2836928849418956
Iteration 1321, Percentage: 33.025, Average Loss: 3.1837498952895356
Iteration 1322, Percentage: 33.050000000000004, Average Loss: 3.441187068647177
Iteration 1323, Percentage: 33.074999999999996, Average Loss: 3.0922438203294234
Iteration 1324, Percentage: 33.1, Average Loss: 3.2458402581882284
Iteration 1325, Percentage: 33.125, Average Loss: 3.42175544728642
Iteration 1326, Percentage: 33.15, Average Loss: 3.3197748128318776
Iteration 1327, Percentage: 33.175, Average Loss: 3.285954905130196
Iteration 1328, Percentage

Iteration 1433, Percentage: 35.825, Average Loss: 3.183939540331389
Iteration 1434, Percentage: 35.85, Average Loss: 3.261479772774813
Iteration 1435, Percentage: 35.875, Average Loss: 3.4369582767688613
Iteration 1436, Percentage: 35.9, Average Loss: 3.3735403858969364
Iteration 1437, Percentage: 35.925000000000004, Average Loss: 3.6508747191142166
Iteration 1438, Percentage: 35.949999999999996, Average Loss: 3.4492261600084513
Iteration 1439, Percentage: 35.975, Average Loss: 3.2190529116582756
Iteration 1440, Percentage: 36.0, Average Loss: 3.277371060095667
Iteration 1441, Percentage: 36.025, Average Loss: 3.574832677529621
Iteration 1442, Percentage: 36.05, Average Loss: 3.2521326428712265
Iteration 1443, Percentage: 36.075, Average Loss: 3.2819344627303835
Iteration 1444, Percentage: 36.1, Average Loss: 3.1170206480083595
Iteration 1445, Percentage: 36.125, Average Loss: 3.4048809057064875
Iteration 1446, Percentage: 36.15, Average Loss: 3.496782213769691
Iteration 1447, Percenta

Iteration 1551, Percentage: 38.775, Average Loss: 3.3104632666974485
Iteration 1552, Percentage: 38.800000000000004, Average Loss: 3.6396263329302796
Iteration 1553, Percentage: 38.824999999999996, Average Loss: 3.2897130853748267
Iteration 1554, Percentage: 38.85, Average Loss: 3.2083667384251022
Iteration 1555, Percentage: 38.875, Average Loss: 3.1372204054338977
Iteration 1556, Percentage: 38.9, Average Loss: 3.5335256452019594
Iteration 1557, Percentage: 38.925, Average Loss: 3.4951705154759334
Iteration 1558, Percentage: 38.95, Average Loss: 3.0842935694692835
Iteration 1559, Percentage: 38.975, Average Loss: 3.2469629835594898
Iteration 1560, Percentage: 39.0, Average Loss: 3.1089704138610292
Iteration 1561, Percentage: 39.025, Average Loss: 3.3294948808912
Iteration 1562, Percentage: 39.050000000000004, Average Loss: 3.4498461932484936
Iteration 1563, Percentage: 39.074999999999996, Average Loss: 3.25922105363696
Iteration 1564, Percentage: 39.1, Average Loss: 2.921162301135044


Iteration 1669, Percentage: 41.725, Average Loss: 3.180650547267796
Iteration 1670, Percentage: 41.75, Average Loss: 3.1615894005989262
Iteration 1671, Percentage: 41.775, Average Loss: 2.9737792051008602
Iteration 1672, Percentage: 41.8, Average Loss: 3.2853858805579867
Iteration 1673, Percentage: 41.825, Average Loss: 3.1076301002734765
Iteration 1674, Percentage: 41.85, Average Loss: 3.155852662968191
Iteration 1675, Percentage: 41.875, Average Loss: 3.420539110773781
Iteration 1676, Percentage: 41.9, Average Loss: 3.555568313051481
Iteration 1677, Percentage: 41.925000000000004, Average Loss: 3.4308895168400766
Iteration 1678, Percentage: 41.949999999999996, Average Loss: 2.9800146589609415
Iteration 1679, Percentage: 41.975, Average Loss: 3.1493841532511504
Iteration 1680, Percentage: 42.0, Average Loss: 3.0796544022571957
Iteration 1681, Percentage: 42.025, Average Loss: 3.298888911513535
Iteration 1682, Percentage: 42.05, Average Loss: 3.519988313322306
Iteration 1683, Percentag

Iteration 1789, Percentage: 44.725, Average Loss: 3.158018927834928
Iteration 1790, Percentage: 44.75, Average Loss: 3.257776073332417
Iteration 1791, Percentage: 44.775, Average Loss: 3.3600442282591882
Iteration 1792, Percentage: 44.800000000000004, Average Loss: 3.496717151392389
Iteration 1793, Percentage: 44.824999999999996, Average Loss: 3.352646982090887
Iteration 1794, Percentage: 44.85, Average Loss: 3.190838565970173
Iteration 1795, Percentage: 44.875, Average Loss: 3.188649113341492
Iteration 1796, Percentage: 44.9, Average Loss: 3.3465135488705826
Iteration 1797, Percentage: 44.925, Average Loss: 3.2794610281730656
Iteration 1798, Percentage: 44.95, Average Loss: 3.373394098971039
Iteration 1799, Percentage: 44.975, Average Loss: 3.2011510001291046
Iteration 1800, Percentage: 45.0, Average Loss: 3.1541110833390786
Iteration 1801, Percentage: 45.025, Average Loss: 3.121031187277017
Iteration 1802, Percentage: 45.050000000000004, Average Loss: 3.100264046270126
Iteration 1803

Iteration 1909, Percentage: 47.725, Average Loss: 2.9845765987335833
Iteration 1910, Percentage: 47.75, Average Loss: 3.1807502618183374
Iteration 1911, Percentage: 47.775, Average Loss: 2.912475284124923
Iteration 1912, Percentage: 47.8, Average Loss: 3.1500537013797616
Iteration 1913, Percentage: 47.825, Average Loss: 2.9602682997752967
Iteration 1914, Percentage: 47.85, Average Loss: 3.1745516040573154
Iteration 1915, Percentage: 47.875, Average Loss: 3.1066213990075657
Iteration 1916, Percentage: 47.9, Average Loss: 3.1727529025039614
Iteration 1917, Percentage: 47.925000000000004, Average Loss: 3.11184133050398
Iteration 1918, Percentage: 47.949999999999996, Average Loss: 3.0833258010057683
Iteration 1919, Percentage: 47.975, Average Loss: 3.2547801483789103
Iteration 1920, Percentage: 48.0, Average Loss: 3.1879450098153144
Iteration 1921, Percentage: 48.025, Average Loss: 3.10549600485109
Iteration 1922, Percentage: 48.05, Average Loss: 3.2273565896705136
Iteration 1923, Percenta

Iteration 2028, Percentage: 50.7, Average Loss: 3.1360182847971214
Iteration 2029, Percentage: 50.724999999999994, Average Loss: 2.9707977483361883
Iteration 2030, Percentage: 50.74999999999999, Average Loss: 3.2361244172591364
Iteration 2031, Percentage: 50.775000000000006, Average Loss: 3.0606684638379096
Iteration 2032, Percentage: 50.8, Average Loss: 3.0004702399496908
Iteration 2033, Percentage: 50.824999999999996, Average Loss: 3.3168005476111415
Iteration 2034, Percentage: 50.849999999999994, Average Loss: 3.186329256791699
Iteration 2035, Percentage: 50.875, Average Loss: 3.198514884228095
Iteration 2036, Percentage: 50.9, Average Loss: 3.003899499711624
Iteration 2037, Percentage: 50.925, Average Loss: 3.233868178774409
Iteration 2038, Percentage: 50.949999999999996, Average Loss: 2.9403701841142054
Iteration 2039, Percentage: 50.975, Average Loss: 3.099134317311986
Iteration 2040, Percentage: 51.0, Average Loss: 2.9820069377549516
Iteration 2041, Percentage: 51.025, Average L

Iteration 2141, Percentage: 53.525, Average Loss: 3.2043019484534057
Iteration 2142, Percentage: 53.55, Average Loss: 3.0557967355517595
Iteration 2143, Percentage: 53.574999999999996, Average Loss: 3.3574752986555914
Iteration 2144, Percentage: 53.6, Average Loss: 3.1184631239622833
Iteration 2145, Percentage: 53.625, Average Loss: 3.011054905755517
Iteration 2146, Percentage: 53.65, Average Loss: 2.887044956870332
Iteration 2147, Percentage: 53.675, Average Loss: 3.168643992479834
Iteration 2148, Percentage: 53.7, Average Loss: 3.0538585137229326
Iteration 2149, Percentage: 53.725, Average Loss: 2.8301242672494844
Iteration 2150, Percentage: 53.75, Average Loss: 2.80692044692114
Iteration 2151, Percentage: 53.77499999999999, Average Loss: 3.1673003243443407
Iteration 2152, Percentage: 53.800000000000004, Average Loss: 3.025409870464807
Iteration 2153, Percentage: 53.825, Average Loss: 3.0739756756275893
Iteration 2154, Percentage: 53.849999999999994, Average Loss: 3.076554905746799
I

Iteration 2255, Percentage: 56.375, Average Loss: 3.186422944547021
Iteration 2256, Percentage: 56.39999999999999, Average Loss: 3.229406491307089
Iteration 2257, Percentage: 56.425000000000004, Average Loss: 2.9825921947020664
Iteration 2258, Percentage: 56.45, Average Loss: 3.052572268029846
Iteration 2259, Percentage: 56.474999999999994, Average Loss: 3.066309322531407
Iteration 2260, Percentage: 56.49999999999999, Average Loss: 2.8992416263387946
Iteration 2261, Percentage: 56.525000000000006, Average Loss: 3.0861605815374977
Iteration 2262, Percentage: 56.55, Average Loss: 3.071454723998157
Iteration 2263, Percentage: 56.574999999999996, Average Loss: 2.9080495290933563
Iteration 2264, Percentage: 56.599999999999994, Average Loss: 3.0820506378176575
Iteration 2265, Percentage: 56.625, Average Loss: 2.8513929572004777
Iteration 2266, Percentage: 56.65, Average Loss: 3.03095661240655
Iteration 2267, Percentage: 56.675, Average Loss: 3.483992670375602
Iteration 2268, Percentage: 56.6

Iteration 2370, Percentage: 59.25, Average Loss: 3.0137141091227533
Iteration 2371, Percentage: 59.275, Average Loss: 3.01410566372891
Iteration 2372, Percentage: 59.3, Average Loss: 2.9625741518939597
Iteration 2373, Percentage: 59.325, Average Loss: 3.006168577023914
Iteration 2374, Percentage: 59.35, Average Loss: 3.168620836671169
Iteration 2375, Percentage: 59.375, Average Loss: 3.029876188858223
Iteration 2376, Percentage: 59.4, Average Loss: 2.9707951484060917
Iteration 2377, Percentage: 59.425, Average Loss: 2.8156824836350456
Iteration 2378, Percentage: 59.45, Average Loss: 2.902621415091871
Iteration 2379, Percentage: 59.475, Average Loss: 3.0449702823762252
Iteration 2380, Percentage: 59.5, Average Loss: 3.1136549111792946
Iteration 2381, Percentage: 59.52499999999999, Average Loss: 2.882590206719051
Iteration 2382, Percentage: 59.550000000000004, Average Loss: 3.0002281661727466
Iteration 2383, Percentage: 59.575, Average Loss: 3.203620347633495
Iteration 2384, Percentage: 

Iteration 2485, Percentage: 62.125, Average Loss: 3.1225777903237217
Iteration 2486, Percentage: 62.150000000000006, Average Loss: 3.1682758839490512
Iteration 2487, Percentage: 62.175000000000004, Average Loss: 3.2517170739543633
Iteration 2488, Percentage: 62.2, Average Loss: 3.053267594859704
Iteration 2489, Percentage: 62.224999999999994, Average Loss: 3.075713383800843
Iteration 2490, Percentage: 62.25000000000001, Average Loss: 2.8218260161929924
Iteration 2491, Percentage: 62.275000000000006, Average Loss: 2.931440365278382
Iteration 2492, Percentage: 62.3, Average Loss: 3.0561662180354046
Iteration 2493, Percentage: 62.324999999999996, Average Loss: 3.133639098934518
Iteration 2494, Percentage: 62.35000000000001, Average Loss: 3.2923458156371135
Iteration 2495, Percentage: 62.375, Average Loss: 3.3126639167398766
Iteration 2496, Percentage: 62.4, Average Loss: 2.8653163934112205
Iteration 2497, Percentage: 62.425, Average Loss: 3.01641638091349
Iteration 2498, Percentage: 62.45

Iteration 2603, Percentage: 65.075, Average Loss: 2.7621029720764216
Iteration 2604, Percentage: 65.10000000000001, Average Loss: 3.022746167509329
Iteration 2605, Percentage: 65.125, Average Loss: 3.0290149447948713
Iteration 2606, Percentage: 65.14999999999999, Average Loss: 3.0907225170613906
Iteration 2607, Percentage: 65.17500000000001, Average Loss: 2.7790546783893455
Iteration 2608, Percentage: 65.2, Average Loss: 2.985991682829883
Iteration 2609, Percentage: 65.225, Average Loss: 3.0815879799248207
Iteration 2610, Percentage: 65.25, Average Loss: 2.8633786398814736
Iteration 2611, Percentage: 65.275, Average Loss: 2.9642012330237777
Iteration 2612, Percentage: 65.3, Average Loss: 2.9341860380498024
Iteration 2613, Percentage: 65.325, Average Loss: 2.8664575015143745
Iteration 2614, Percentage: 65.35, Average Loss: 2.9116818859910545
Iteration 2615, Percentage: 65.375, Average Loss: 2.8445179400961007
Iteration 2616, Percentage: 65.4, Average Loss: 2.993432072855586
Iteration 26

Iteration 2722, Percentage: 68.05, Average Loss: 2.6274313832948377
Iteration 2723, Percentage: 68.075, Average Loss: 2.9145918409243574
Iteration 2724, Percentage: 68.10000000000001, Average Loss: 3.1877335160794966
Iteration 2725, Percentage: 68.125, Average Loss: 3.0631153917380223
Iteration 2726, Percentage: 68.15, Average Loss: 2.8201541195491773
Iteration 2727, Percentage: 68.175, Average Loss: 3.0002360264651284
Iteration 2728, Percentage: 68.2, Average Loss: 2.7425295715475113
Iteration 2729, Percentage: 68.22500000000001, Average Loss: 3.2548233351491866
Iteration 2730, Percentage: 68.25, Average Loss: 2.9166119268095523
Iteration 2731, Percentage: 68.27499999999999, Average Loss: 2.8297854021881816
Iteration 2732, Percentage: 68.30000000000001, Average Loss: 2.897031468747652
Iteration 2733, Percentage: 68.325, Average Loss: 2.9183967869027225
Iteration 2734, Percentage: 68.35, Average Loss: 3.0990631500472237
Iteration 2735, Percentage: 68.375, Average Loss: 3.10367393864772

Iteration 2842, Percentage: 71.05, Average Loss: 2.8384990022396264
Iteration 2843, Percentage: 71.075, Average Loss: 2.836273159367234
Iteration 2844, Percentage: 71.1, Average Loss: 3.192573163464171
Iteration 2845, Percentage: 71.125, Average Loss: 2.768073086378787
Iteration 2846, Percentage: 71.15, Average Loss: 2.8675335623929277
Iteration 2847, Percentage: 71.175, Average Loss: 3.084599138562416
Iteration 2848, Percentage: 71.2, Average Loss: 2.7504817148073193
Iteration 2849, Percentage: 71.22500000000001, Average Loss: 2.886747251994884
Iteration 2850, Percentage: 71.25, Average Loss: 3.150974417347461
Iteration 2851, Percentage: 71.275, Average Loss: 3.035718551851236
Iteration 2852, Percentage: 71.3, Average Loss: 2.8185736899102727
Iteration 2853, Percentage: 71.325, Average Loss: 2.8763652146749554
Iteration 2854, Percentage: 71.35000000000001, Average Loss: 2.9014335686664543
Iteration 2855, Percentage: 71.375, Average Loss: 2.8452144596570483
Iteration 2856, Percentage: 

Iteration 2964, Percentage: 74.1, Average Loss: 2.8487178022389963
Iteration 2965, Percentage: 74.125, Average Loss: 2.895419756610921
Iteration 2966, Percentage: 74.15, Average Loss: 2.7652550240340967
Iteration 2967, Percentage: 74.175, Average Loss: 2.966804581013422
Iteration 2968, Percentage: 74.2, Average Loss: 3.1167627916966687
Iteration 2969, Percentage: 74.225, Average Loss: 2.9467445979138804
Iteration 2970, Percentage: 74.25, Average Loss: 2.679543932342942
Iteration 2971, Percentage: 74.275, Average Loss: 2.704494328484956
Iteration 2972, Percentage: 74.3, Average Loss: 2.6561659496464327
Iteration 2973, Percentage: 74.325, Average Loss: 2.821262700564788
Iteration 2974, Percentage: 74.35000000000001, Average Loss: 2.636015704730287
Iteration 2975, Percentage: 74.375, Average Loss: 2.6811753761672237
Iteration 2976, Percentage: 74.4, Average Loss: 2.956419830925958
Iteration 2977, Percentage: 74.425, Average Loss: 2.7801553012288815
Iteration 2978, Percentage: 74.45, Avera

Iteration 3084, Percentage: 77.10000000000001, Average Loss: 2.6844784137840403
Iteration 3085, Percentage: 77.125, Average Loss: 2.6060928125655103
Iteration 3086, Percentage: 77.14999999999999, Average Loss: 2.8454667977307495
Iteration 3087, Percentage: 77.17500000000001, Average Loss: 3.0718814565430774
Iteration 3088, Percentage: 77.2, Average Loss: 2.9261227558146707
Iteration 3089, Percentage: 77.225, Average Loss: 2.6315083261904424
Iteration 3090, Percentage: 77.25, Average Loss: 2.7863523856718184
Iteration 3091, Percentage: 77.275, Average Loss: 2.794542973646357
Iteration 3092, Percentage: 77.3, Average Loss: 2.983737064791577
Iteration 3093, Percentage: 77.325, Average Loss: 2.7203834844372605
Iteration 3094, Percentage: 77.35, Average Loss: 2.8337415347527593
Iteration 3095, Percentage: 77.375, Average Loss: 2.726931244929054
Iteration 3096, Percentage: 77.4, Average Loss: 2.8013620874700473
Iteration 3097, Percentage: 77.425, Average Loss: 2.7770186882171854
Iteration 30

Iteration 3205, Percentage: 80.125, Average Loss: 2.9688803622475226
Iteration 3206, Percentage: 80.15, Average Loss: 2.7102897686810157
Iteration 3207, Percentage: 80.175, Average Loss: 2.751064862566004
Iteration 3208, Percentage: 80.2, Average Loss: 2.6999490548668934
Iteration 3209, Percentage: 80.22500000000001, Average Loss: 2.7352512166893885
Iteration 3210, Percentage: 80.25, Average Loss: 2.695139018626069
Iteration 3211, Percentage: 80.27499999999999, Average Loss: 2.830566843447741
Iteration 3212, Percentage: 80.30000000000001, Average Loss: 2.5657827160506455
Iteration 3213, Percentage: 80.325, Average Loss: 2.750845739481202
Iteration 3214, Percentage: 80.35, Average Loss: 2.5948400745511093
Iteration 3215, Percentage: 80.375, Average Loss: 2.768879926949269
Iteration 3216, Percentage: 80.4, Average Loss: 2.504568768050796
Iteration 3217, Percentage: 80.425, Average Loss: 2.734811899541163
Iteration 3218, Percentage: 80.45, Average Loss: 2.9711925348131474
Iteration 3219, 

Iteration 3322, Percentage: 83.05, Average Loss: 2.6652811164736074
Iteration 3323, Percentage: 83.075, Average Loss: 2.702864023854569
Iteration 3324, Percentage: 83.1, Average Loss: 2.6838952089260917
Iteration 3325, Percentage: 83.125, Average Loss: 2.6667996858468817
Iteration 3326, Percentage: 83.15, Average Loss: 2.757404763969387
Iteration 3327, Percentage: 83.175, Average Loss: 2.995887811110362
Iteration 3328, Percentage: 83.2, Average Loss: 2.8694052960959575
Iteration 3329, Percentage: 83.22500000000001, Average Loss: 2.9284217856824397
Iteration 3330, Percentage: 83.25, Average Loss: 2.551128963692989
Iteration 3331, Percentage: 83.275, Average Loss: 2.6720001000446523
Iteration 3332, Percentage: 83.3, Average Loss: 2.7783023560627726
Iteration 3333, Percentage: 83.325, Average Loss: 2.682521833411235
Iteration 3334, Percentage: 83.35000000000001, Average Loss: 2.5625294806220076
Iteration 3335, Percentage: 83.375, Average Loss: 2.618466776103344
Iteration 3336, Percentage:

Iteration 3442, Percentage: 86.05000000000001, Average Loss: 2.7238758691882654
Iteration 3443, Percentage: 86.075, Average Loss: 2.7111057292549727
Iteration 3444, Percentage: 86.1, Average Loss: 2.760332075450145
Iteration 3445, Percentage: 86.125, Average Loss: 2.7798115216977433
Iteration 3446, Percentage: 86.15, Average Loss: 2.7267325796784463
Iteration 3447, Percentage: 86.175, Average Loss: 2.719320778728981
Iteration 3448, Percentage: 86.2, Average Loss: 2.7397641203496135
Iteration 3449, Percentage: 86.225, Average Loss: 2.84439103046972
Iteration 3450, Percentage: 86.25, Average Loss: 2.92231787069767
Iteration 3451, Percentage: 86.275, Average Loss: 2.662040321409103
Iteration 3452, Percentage: 86.3, Average Loss: 2.428170686279111
Iteration 3453, Percentage: 86.325, Average Loss: 2.6229873047087047
Iteration 3454, Percentage: 86.35000000000001, Average Loss: 3.004677284332807
Iteration 3455, Percentage: 86.375, Average Loss: 2.6390235452414172
Iteration 3456, Percentage: 8

Iteration 3561, Percentage: 89.025, Average Loss: 2.7137964515170703
Iteration 3562, Percentage: 89.05, Average Loss: 2.6746283523428755
Iteration 3563, Percentage: 89.075, Average Loss: 2.545532242863009
Iteration 3564, Percentage: 89.1, Average Loss: 2.698339527351658
Iteration 3565, Percentage: 89.125, Average Loss: 2.3673296357160156
Iteration 3566, Percentage: 89.14999999999999, Average Loss: 2.6129816100769436
Iteration 3567, Percentage: 89.17500000000001, Average Loss: 2.526845469187582
Iteration 3568, Percentage: 89.2, Average Loss: 2.733652295470238
Iteration 3569, Percentage: 89.225, Average Loss: 2.8273858432686985
Iteration 3570, Percentage: 89.25, Average Loss: 2.746686631933816
Iteration 3571, Percentage: 89.275, Average Loss: 2.8542636646678052
Iteration 3572, Percentage: 89.3, Average Loss: 2.587186428155787
Iteration 3573, Percentage: 89.325, Average Loss: 3.000721425906598
Iteration 3574, Percentage: 89.35, Average Loss: 2.4705487249582134
Iteration 3575, Percentage: 

Iteration 3681, Percentage: 92.025, Average Loss: 2.474843087631303
Iteration 3682, Percentage: 92.05, Average Loss: 2.3531244518880796
Iteration 3683, Percentage: 92.07499999999999, Average Loss: 2.5875372584084517
Iteration 3684, Percentage: 92.10000000000001, Average Loss: 2.698399322490032
Iteration 3685, Percentage: 92.125, Average Loss: 2.5730238245338084
Iteration 3686, Percentage: 92.15, Average Loss: 2.509922843868085
Iteration 3687, Percentage: 92.175, Average Loss: 2.2744317367920948
Iteration 3688, Percentage: 92.2, Average Loss: 2.503662440481834
Iteration 3689, Percentage: 92.225, Average Loss: 2.5012406741941424
Iteration 3690, Percentage: 92.25, Average Loss: 2.4596468849440076
Iteration 3691, Percentage: 92.27499999999999, Average Loss: 2.718143350129466
Iteration 3692, Percentage: 92.30000000000001, Average Loss: 2.654898643089827
Iteration 3693, Percentage: 92.325, Average Loss: 2.6438569180483436
Iteration 3694, Percentage: 92.35, Average Loss: 2.784280123765072
Ite

Iteration 3799, Percentage: 94.975, Average Loss: 2.4529102406874297
Iteration 3800, Percentage: 95.0, Average Loss: 2.4378464211876194
Iteration 3801, Percentage: 95.025, Average Loss: 2.713348522266315
Iteration 3802, Percentage: 95.05, Average Loss: 2.413558646422329
Iteration 3803, Percentage: 95.075, Average Loss: 2.7320666818736266
Iteration 3804, Percentage: 95.1, Average Loss: 2.6494355639542144
Iteration 3805, Percentage: 95.125, Average Loss: 2.497802441516383
Iteration 3806, Percentage: 95.15, Average Loss: 2.6837576635795948
Iteration 3807, Percentage: 95.175, Average Loss: 2.5220810942059977
Iteration 3808, Percentage: 95.19999999999999, Average Loss: 2.43879188220948
Iteration 3809, Percentage: 95.22500000000001, Average Loss: 2.641736378573996
Iteration 3810, Percentage: 95.25, Average Loss: 2.275121478064868
Iteration 3811, Percentage: 95.275, Average Loss: 2.7484297938007307
Iteration 3812, Percentage: 95.3, Average Loss: 2.523068225959047
Iteration 3813, Percentage: 9

Iteration 3917, Percentage: 97.925, Average Loss: 2.70019668734121
Iteration 3918, Percentage: 97.95, Average Loss: 2.7835243041078215
Iteration 3919, Percentage: 97.975, Average Loss: 2.6631205425490827
Iteration 3920, Percentage: 98.0, Average Loss: 2.456504618715676
Iteration 3921, Percentage: 98.02499999999999, Average Loss: 2.666791170712232
Iteration 3922, Percentage: 98.05, Average Loss: 2.452332375364171
Iteration 3923, Percentage: 98.075, Average Loss: 2.5215837822999445
Iteration 3924, Percentage: 98.1, Average Loss: 2.5001593213584634
Iteration 3925, Percentage: 98.125, Average Loss: 2.625914762603408
Iteration 3926, Percentage: 98.15, Average Loss: 2.5666587971242665
Iteration 3927, Percentage: 98.175, Average Loss: 2.575789745454374
Iteration 3928, Percentage: 98.2, Average Loss: 2.481436178505269
Iteration 3929, Percentage: 98.225, Average Loss: 2.6272999192903
Iteration 3930, Percentage: 98.25, Average Loss: 2.507491853401279
Iteration 3931, Percentage: 98.275, Average L

In [169]:
encoder.eval()
decoder.eval()
searcher = GreedySearchDecoder(encoder, decoder)

In [170]:
evaluateInput(encoder, decoder, searcher, voc)

> Hi! How are you?
Bot: Fine . . . . .
> Tell me something
Bot: I don t know .
> Why don't you know?
Bot: I don t know . . .
> Who are you?
Bot: I m just looking . . .
> What?
Bot: I don t know . . .
> President
Bot: What ? ? ? ?
> I am the president
Bot: You re not gonna be in the truck .
> I am not in the truck
Bot: I ll be there .
> Where are you?
Bot: I m going to talk to you . .
> Don't talk to me
Bot: Shut up ! I m sorry !
> Don't worry
Bot: I m sorry . I m sorry .
> Why are you sorry?
Bot: I m not sure . .
> Do you like movies?
Bot: Yes . . . . .
> What is your favorite?
Bot: My brother . . . .
> Who is your brother?
Bot: My brother . . . .
> You okay?
Bot: Yeah . . . . .
> Tell me a story
Bot: I m sorry . I love you .
> Why are you so selfish?
Bot: I m going to be fine . .
> I'm a possession
Bot: You re a girl ! ! !
> Are you kidding?
Bot: Yes . . . . .
> You got me involved!!!
Bot: I m sorry . I m sorry .
> I have to be home in twenty minutes
Bot: You ve been a good time .
> Y