## Data Cleaning & Processing

In [1]:
# Import required packages
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import csv
import random
import re
import os
import unicodedata
import codecs
import itertools

CUDA = torch.cuda.is_available()
device = torch.device("cuda" if CUDA else "cpu")

In [2]:
lines_filepath = os.path.join("cornell_movie_data", "movie_lines.txt")
conv_filepath = os.path.join("cornell_movie_data", "movie_conversations.txt")

In [3]:
# Load & Preprocess Data
with open(lines_filepath, 'r') as file:
    lines = file.readlines()
for line in lines[:8]:
    print(line.strip())

L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!
L1044 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ They do to!
L985 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I hope so.
L984 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ She okay?
L925 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Let's go.
L924 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ Wow
L872 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Okay -- you're gonna need to learn how to lie.
L871 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ No


In [4]:
# Split lines into different fields
line_fields = ["lineID", "characterID", "movieID", "character", "text"]
lines = {}
with open(lines_filepath, 'r', encoding='iso-8859-1') as f:
    for line in f:
        values = line.split(" +++$+++ ")
        lineObj = {}
        for i, field in enumerate(line_fields):
            lineObj[field] = values[i]
        lines[lineObj['lineID']] = lineObj

In [5]:
lines

{'L1045': {'lineID': 'L1045',
  'characterID': 'u0',
  'movieID': 'm0',
  'character': 'BIANCA',
  'text': 'They do not!\n'},
 'L1044': {'lineID': 'L1044',
  'characterID': 'u2',
  'movieID': 'm0',
  'character': 'CAMERON',
  'text': 'They do to!\n'},
 'L985': {'lineID': 'L985',
  'characterID': 'u0',
  'movieID': 'm0',
  'character': 'BIANCA',
  'text': 'I hope so.\n'},
 'L984': {'lineID': 'L984',
  'characterID': 'u2',
  'movieID': 'm0',
  'character': 'CAMERON',
  'text': 'She okay?\n'},
 'L925': {'lineID': 'L925',
  'characterID': 'u0',
  'movieID': 'm0',
  'character': 'BIANCA',
  'text': "Let's go.\n"},
 'L924': {'lineID': 'L924',
  'characterID': 'u2',
  'movieID': 'm0',
  'character': 'CAMERON',
  'text': 'Wow\n'},
 'L872': {'lineID': 'L872',
  'characterID': 'u0',
  'movieID': 'm0',
  'character': 'BIANCA',
  'text': "Okay -- you're gonna need to learn how to lie.\n"},
 'L871': {'lineID': 'L871',
  'characterID': 'u2',
  'movieID': 'm0',
  'character': 'CAMERON',
  'text': 'No

In [6]:
conv_fields = ["characterID", "chracter2ID", "movieID", "utteranceIDs"]
conversations = []
with open(conv_filepath, 'r', encoding='iso-8859-1') as f:
    for line in f:
        values = line.split(" +++$+++ ")
        convObj = {}
        for i, field in enumerate(conv_fields):
            convObj[field] = values[i]
        lineIds = eval(convObj["utteranceIDs"])
        
        convObj["lines"] = []
        for lineId in lineIds:
            convObj["lines"].append(lines[lineId])
        conversations.append(convObj)

In [7]:
conversations[0]

{'characterID': 'u0',
 'chracter2ID': 'u2',
 'movieID': 'm0',
 'utteranceIDs': "['L194', 'L195', 'L196', 'L197']\n",
 'lines': [{'lineID': 'L194',
   'characterID': 'u0',
   'movieID': 'm0',
   'character': 'BIANCA',
   'text': 'Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\n'},
  {'lineID': 'L195',
   'characterID': 'u2',
   'movieID': 'm0',
   'character': 'CAMERON',
   'text': "Well, I thought we'd start with pronunciation, if that's okay with you.\n"},
  {'lineID': 'L196',
   'characterID': 'u0',
   'movieID': 'm0',
   'character': 'BIANCA',
   'text': 'Not the hacking and gagging and spitting part.  Please.\n'},
  {'lineID': 'L197',
   'characterID': 'u2',
   'movieID': 'm0',
   'character': 'CAMERON',
   'text': "Okay... then how 'bout we try out some French cuisine.  Saturday?  Night?\n"}]}

In [8]:
# extracting pairs of conversations
qa_pairs = []
for conversation in conversations:
    for i in range(len(conversation["lines"]) - 1):
        inputLine = conversation["lines"][i]["text"].strip()
        targetLine = conversation["lines"][i+1]["text"].strip()
        
        if inputLine and targetLine:
            qa_pairs.append([inputLine, targetLine])

In [9]:
qa_pairs[0]

['Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.',
 "Well, I thought we'd start with pronunciation, if that's okay with you."]

In [10]:
# define the file to save
datafile = os.path.join("cornell_movie_data", "formatted_movie_lines.txt")
delimiter = '\t'
# Unescape the delimiter
delimiter = str(codecs.decode(delimiter, "unicode_escape"))
print("\nWriting newly formatted file...")
# Write to a file
with open(datafile, 'w', encoding='utf-8') as outputfile:
    writer = csv.writer(outputfile, delimiter=delimiter)
    for pair in qa_pairs:
        writer.writerow(pair)
print("Writing to file Done.")


Writing newly formatted file...
Writing to file Done.


In [11]:
# Check some of the lines
datafile = os.path.join("cornell_movie_data", "formatted_movie_lines.txt")

with open(datafile, 'rb') as file:
    lines = file.readlines()
for line in lines[:8]:
    print(line.strip())

b"Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\tWell, I thought we'd start with pronunciation, if that's okay with you."
b"Well, I thought we'd start with pronunciation, if that's okay with you.\tNot the hacking and gagging and spitting part.  Please."
b"Not the hacking and gagging and spitting part.  Please.\tOkay... then how 'bout we try out some French cuisine.  Saturday?  Night?"
b"You're asking me out.  That's so cute. What's your name again?\tForget it."
b"No, no, it's my fault -- we didn't have a proper introduction ---\tCameron."
b"Cameron.\tThe thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does."
b"The thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\tSeems like she could get a date easy enough..."
b'Why?\tUnsolved mystery.  She used to be really po

In [12]:
PAD_token = 0 # User for paddign short sentences
SOS_token = 1 # Start of the sentences
EOS_token = 2 # End of sentences

class Vocabulary:
    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3
        
    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)
            
    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1
        else:
            self.word2count[word] += 1
            
    def trim(self, min_count):
        keep_words = []
        for k, v in self.word2count.items():
            if v >= min_count:
                keep_words.append(k)
        
        print('keep_words {} / {} = {:.4f}'.format(len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index)))
        # Reinitialize Dictionaries
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3
              
        for word in keep_words:
              self.addWord(word)

In [13]:
# Turn a Unicode string to ASCII
def unicodeToAscii(s):
    return ''.join(c for c in unicodedata.normalize('NFD', s) if unicodedata.category(c) != 'Mn')

In [14]:
# Lowercase, trim, and remove non-letter characters
def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    # Replace any .!? by a whitespace + the character --> '!' = ' !' .\1 means the first bracketed group --> [,!?]. r is to 
    # not consider  \1 as a character (r to escape backslash).
    s = re.sub(r"([.!?])", r" \1", s)
    # Remove anhy character that is not a sequence of  lower or upper case letters. + means one more 
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    # Remove a sequence of whitespace characters
    s = re.sub(r"\s+", r" ", s).strip()
    return s

In [15]:
datafile = os.path.join('cornell_movie_data', 'formatted_movie_lines.txt')
print("Reading & processing the files.Please wait...")
lines = open(datafile, encoding='utf-8').read().strip().split('\n')
# Split every line into pairs and normalize
pairs = [[normalizeString(s) for s in pair.split('\t')] for pair in lines]
print("Done reading.")
voc = Vocabulary("cornell movie dialogs")

Reading & processing the files.Please wait...
Done reading.


In [16]:
# Returns True if both sentense in a pair 'p' are under the MAX_LENGTH threshold
MAX_LENGTH = 10 # maximum sentence length to consider (max words)
def filterPair(p):
    # Input token need to preserve the last word for EOS token
    return len(p[0].split()) < MAX_LENGTH and len(p[1].split()) < MAX_LENGTH

# filter pair using filterPair consition
def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]

In [44]:
pairs = [pair for pair in pairs if len(pair)>1]
print("There are {} pairs/conversations in the dataset".format(len(pairs)))
pairs = filterPairs(pairs)
print("After filtering, there are {} pairs/conversations".format(len(pairs)))

save_dir = os.path.join("data", "save")

There are 53165 pairs/conversations in the dataset
After filtering, there are 53165 pairs/conversations


In [18]:
# Loop through each pair of  and add the question and reply sentence to teh vocabulary
for pair in pairs:
    voc.addSentence(pair[0])
    voc.addSentence(pair[1])
print("Counted words:", voc.num_words)
for pair in pairs[:10]:
    print(pair)

Counted words: 18008
['there .', 'where ?']
['you have my word . as a gentleman', 'you re sweet .']
['hi .', 'looks like things worked out tonight huh ?']
['you know chastity ?', 'i believe we share an art instructor']
['have fun tonight ?', 'tons']
['well no . . .', 'then that s all you had to say .']
['then that s all you had to say .', 'but']
['but', 'you always been this selfish ?']
['do you listen to this crap ?', 'what crap ?']
['what good stuff ?', 'the real you .']


In [19]:
MIN_COUNT = 3 # minimum word count threshold for trimming

def trimRareWords(voc, pairs, MIN_COUNT):
    # Trim words used under the MIN_COUNT from the voc
    voc.trim(MIN_COUNT)
    # filter out the pairs with trimmed words
    keep_pairs = []
    for pair in pairs:
        input_sentence = pair[0]
        output_sentence = pair[1]
        keep_input = True
        keep_output = True
        # check input sentence
        for word in input_sentence.split(' '):
            if word not in voc.word2index:
                keep_input = False
                break
        # check output sentence
        for word in output_sentence.split(' '):
            if word not in voc.word2index:
                keep_output = False
                break
                
        # Only keep the pairs that do not contain trimmed word(s) in their input or output sentense
        if keep_input and keep_output:
            keep_pairs.append(pair)
            
    print("Trimmed from {} pairs to {}, {:.4f} of total".format(len(pairs), len(keep_pairs), len(keep_pairs) / len(pairs)))
    return keep_pairs

# Trim voc and pairs
pairs = trimRareWords(voc, pairs, MIN_COUNT)

keep_words 7823 / 18005 = 0.4345
Trimmed from 64271 pairs to 53165, 0.8272 of total


In [20]:
def indexesFromSentence(voc, sentence):
    return [voc.word2index[word] for word in sentence.split(' ')] + [EOS_token]

In [21]:
# Test the above function
indexesFromSentence(voc, pairs[1][0])

[7, 8, 9, 10, 4, 11, 12, 13, 2]

In [22]:
# Define some sample for testing
inp = []
out = []
for pair in pairs[:10]:
    inp.append(pair[0])
    out.append(pair[1])
print(inp)
print(len(inp))
indexes = [indexesFromSentence(voc, sentence) for sentence in inp]
indexes

['there .', 'you have my word . as a gentleman', 'hi .', 'have fun tonight ?', 'well no . . .', 'then that s all you had to say .', 'but', 'do you listen to this crap ?', 'what good stuff ?', 'wow']
10


[[3, 4, 2],
 [7, 8, 9, 10, 4, 11, 12, 13, 2],
 [16, 4, 2],
 [8, 31, 22, 6, 2],
 [33, 34, 4, 4, 4, 2],
 [35, 36, 37, 38, 7, 39, 40, 41, 4, 2],
 [42, 2],
 [47, 7, 48, 40, 45, 49, 6, 2],
 [50, 51, 52, 6, 2],
 [58, 2]]

In [23]:
def zeroPadding(l, fillvalue=0):
    return list(itertools.zip_longest(*l, fillvalue=fillvalue))

In [24]:
leng = [len(ind) for ind in indexes]
max(leng)

10

In [25]:
# Test the zeropadding funtion
test_result = zeroPadding(indexes)
print(len(test_result))
test_result

10


[(3, 7, 16, 8, 33, 35, 42, 47, 50, 58),
 (4, 8, 4, 31, 34, 36, 2, 7, 51, 2),
 (2, 9, 2, 22, 4, 37, 0, 48, 52, 0),
 (0, 10, 0, 6, 4, 38, 0, 40, 6, 0),
 (0, 4, 0, 2, 4, 7, 0, 45, 2, 0),
 (0, 11, 0, 0, 2, 39, 0, 49, 0, 0),
 (0, 12, 0, 0, 0, 40, 0, 6, 0, 0),
 (0, 13, 0, 0, 0, 41, 0, 2, 0, 0),
 (0, 2, 0, 0, 0, 4, 0, 0, 0, 0),
 (0, 0, 0, 0, 0, 2, 0, 0, 0, 0)]

In [26]:
def binaryMatrix(l, value=0):
    m = []
    for i, seq in enumerate(l):
        m.append([])
        for token in seq:
            if token == PAD_token:
                m[i].append(0)
            else:
                m[i].append(1)
    return m

In [27]:
binary_result = binaryMatrix(test_result)
binary_result

[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 [1, 1, 1, 1, 1, 1, 0, 1, 1, 0],
 [0, 1, 0, 1, 1, 1, 0, 1, 1, 0],
 [0, 1, 0, 1, 1, 1, 0, 1, 1, 0],
 [0, 1, 0, 0, 1, 1, 0, 1, 0, 0],
 [0, 1, 0, 0, 0, 1, 0, 1, 0, 0],
 [0, 1, 0, 0, 0, 1, 0, 1, 0, 0],
 [0, 1, 0, 0, 0, 1, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]]

In [28]:
# Returns padded input sequence tensor and lengths
def inputVar(l, voc):
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    padVar = torch.LongTensor(padList)
    return padVar, lengths

In [29]:
# Returns padded target sequence tensor, padding mask, and max target length
def outputVar(l, voc):
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    max_target_len = max([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    mask = binaryMatrix(padList)
    mask = torch.BoolTensor(mask)
    padVar = torch.LongTensor(padList)
    return padVar, mask, max_target_len

In [30]:
# Returns all items for a given batch of pairs
def batch2TrainData(voc, pair_batch):
    pair_batch.sort(key=lambda x: len(x[0].split(" ")), reverse=True)
    input_batch, output_batch = [], []
    for pair in pair_batch:
        input_batch.append(pair[0])
        output_batch.append(pair[1])
    inp, lengths = inputVar(input_batch, voc)
    output, mask, max_target_len = outputVar(output_batch, voc)
    return inp, lengths, output, mask, max_target_len

In [31]:
# Example for validation
small_batch_size = 5
batches = batch2TrainData(voc, [random.choice(pairs) for _ in range(small_batch_size)])
input_variable, lengths, target_variable, mask, max_target_len = batches

print("input_variable:", input_variable)
print("lengths:", lengths)
print("target_variable:", target_variable)
print("mask:", mask)
print("max_target_len:", max_target_len)

input_variable: tensor([[  64,   25,   25,  562,    8],
        [  25,   74,  283,  199,   27],
        [ 236,   25,    7,    4, 1331],
        [   7,   89,   72,    4,    6],
        [  14,  296,    8,    4,    2],
        [ 144,   36,    4,    2,    0],
        [ 139,  129,    2,    0,    0],
        [2663,    4,    0,    0,    0],
        [  66,    2,    0,    0,    0],
        [   2,    0,    0,    0,    0]])
lengths: tensor([10,  9,  7,  6,  5])
target_variable: tensor([[1592,   50,   25,  571,   25],
        [   4,   47,   80,    4,   89],
        [   2,    7,  250,    4,   44],
        [   0,  118,    2,    4,   96],
        [   0,   98,    0,    2,  159],
        [   0, 3728,    0,    0,  129],
        [   0,    6,    0,    0,    4],
        [   0,    2,    0,    0,    2]])
mask: tensor([[ True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True],
        [False,  True,  True,  True,  True],
        [False,  True

## seq2seq Model Create

In [32]:
class EncoderRNN(nn.Module):
    def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
        super(EncoderRNN, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embedding = embedding

        # Initialize GRU; the input_size and hidden_size params are both set to 'hidden_size'
        #   because our input size is a word embedding with number of features == hidden_size
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers,
                          dropout=(0 if n_layers == 1 else dropout), bidirectional=True)

    def forward(self, input_seq, input_lengths, hidden=None):
        # Convert word indexes to embeddings
        embedded = self.embedding(input_seq)
        # Pack padded batch of sequences for RNN module
        packed = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
        # Forward pass through GRU
        outputs, hidden = self.gru(packed, hidden)
        # Unpack padding
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs)
        # Sum bidirectional GRU outputs
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, : ,self.hidden_size:]
        # Return output and final hidden state
        return outputs, hidden

In [33]:
# Luong attention layer
class Attn(nn.Module):
    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()
        self.method = method
        if self.method not in ['dot', 'general', 'concat']:
            raise ValueError(self.method, "is not an appropriate attention method.")
        self.hidden_size = hidden_size
        if self.method == 'general':
            self.attn = nn.Linear(self.hidden_size, hidden_size)
        elif self.method == 'concat':
            self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
            self.v = nn.Parameter(torch.FloatTensor(hidden_size))

    def dot_score(self, hidden, encoder_output):
        return torch.sum(hidden * encoder_output, dim=2)

    def general_score(self, hidden, encoder_output):
        energy = self.attn(encoder_output)
        return torch.sum(hidden * energy, dim=2)

    def concat_score(self, hidden, encoder_output):
        energy = self.attn(torch.cat((hidden.expand(encoder_output.size(0), -1, -1), encoder_output), 2)).tanh()
        return torch.sum(self.v * energy, dim=2)

    def forward(self, hidden, encoder_outputs):
        # Calculate the attention weights (energies) based on the given method
        if self.method == 'general':
            attn_energies = self.general_score(hidden, encoder_outputs)
        elif self.method == 'concat':
            attn_energies = self.concat_score(hidden, encoder_outputs)
        elif self.method == 'dot':
            attn_energies = self.dot_score(hidden, encoder_outputs)

        # Transpose max_length and batch_size dimensions
        attn_energies = attn_energies.t()

        # Return the softmax normalized probability scores (with added dimension)
        return F.softmax(attn_energies, dim=1).unsqueeze(1)

In [34]:
class LuongAttnDecoderRNN(nn.Module):
    def __init__(self, attn_model, embedding, hidden_size, output_size, n_layers=1, dropout=0.1):
        super(LuongAttnDecoderRNN, self).__init__()

        # Keep for reference
        self.attn_model = attn_model
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = dropout

        # Define layers
        self.embedding = embedding
        self.embedding_dropout = nn.Dropout(dropout)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=(0 if n_layers == 1 else dropout))
        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

        self.attn = Attn(attn_model, hidden_size)

    def forward(self, input_step, last_hidden, encoder_outputs):
        # Note: we run this one step (word) at a time
        # Get embedding of current input word
        embedded = self.embedding(input_step)
        embedded = self.embedding_dropout(embedded)
        # Forward through unidirectional GRU
        rnn_output, hidden = self.gru(embedded, last_hidden)
        # Calculate attention weights from the current GRU output
        attn_weights = self.attn(rnn_output, encoder_outputs)
        # Multiply attention weights to encoder outputs to get new "weighted sum" context vector
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1))
        # Concatenate weighted context vector and GRU output using Luong eq. 5
        rnn_output = rnn_output.squeeze(0)
        context = context.squeeze(1)
        concat_input = torch.cat((rnn_output, context), 1)
        concat_output = torch.tanh(self.concat(concat_input))
        # Predict next word using Luong eq. 6
        output = self.out(concat_output)
        output = F.softmax(output, dim=1)
        # Return output and final hidden state
        return output, hidden

## Training the Model

In [36]:
def maskNLLLoss(inp, target, mask):
    nTotal = mask.sum()
    target = target.view(-1, 1)
    gathered_tensor = torch.gather(inp, 1, target)
    crossEntropy = -torch.log(gathered_tensor)
    loss = crossEntropy.masked_select(mask)
    loss = loss.mean()
    loss = loss.to(device)
    return loss, nTotal.item()

In [37]:
def train(input_variable, lengths, target_variable, mask, max_target_len, encoder, decoder, embedding,
          encoder_optimizer, decoder_optimizer, batch_size, clip, max_length=MAX_LENGTH):

    # Zero gradients
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    # Set device options
    input_variable = input_variable.to(device)
    target_variable = target_variable.to(device)
    mask = mask.to(device)
    # Lengths for rnn packing should always be on the cpu
    lengths = lengths.to("cpu")

    # Initialize variables
    loss = 0
    print_losses = []
    n_totals = 0

    # Forward pass through encoder
    encoder_outputs, encoder_hidden = encoder(input_variable, lengths)

    # Create initial decoder input (start with SOS tokens for each sentence)
    decoder_input = torch.LongTensor([[SOS_token for _ in range(batch_size)]])
    decoder_input = decoder_input.to(device)

    # Set initial decoder hidden state to the encoder's final hidden state
    decoder_hidden = encoder_hidden[:decoder.n_layers]

    # Determine if we are using teacher forcing this iteration
    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False

    # Forward batch of sequences through decoder one time step at a time
    if use_teacher_forcing:
        for t in range(max_target_len):
            decoder_output, decoder_hidden = decoder(
                decoder_input, decoder_hidden, encoder_outputs
            )
            # Teacher forcing: next input is current target
            decoder_input = target_variable[t].view(1, -1)
            # Calculate and accumulate loss
            mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
            loss += mask_loss
            print_losses.append(mask_loss.item() * nTotal)
            n_totals += nTotal
    else:
        for t in range(max_target_len):
            decoder_output, decoder_hidden = decoder(
                decoder_input, decoder_hidden, encoder_outputs
            )
            # No teacher forcing: next input is decoder's own current output
            _, topi = decoder_output.topk(1)
            decoder_input = torch.LongTensor([[topi[i][0] for i in range(batch_size)]])
            decoder_input = decoder_input.to(device)
            # Calculate and accumulate loss
            mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
            loss += mask_loss
            print_losses.append(mask_loss.item() * nTotal)
            n_totals += nTotal

    # Perform backpropatation
    loss.backward()

    # Clip gradients: gradients are modified in place
    _ = nn.utils.clip_grad_norm_(encoder.parameters(), clip)
    _ = nn.utils.clip_grad_norm_(decoder.parameters(), clip)

    # Adjust model weights
    encoder_optimizer.step()
    decoder_optimizer.step()

    return sum(print_losses) / n_totals

### Training iterations

In [39]:
def trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer, embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size, print_every, save_every, clip, corpus_name, loadFilename):

    # Load batches for each iteration
    training_batches = [batch2TrainData(voc, [random.choice(pairs) for _ in range(batch_size)])
                      for _ in range(n_iteration)]

    # Initializations
    print('Initializing ...')
    start_iteration = 1
    print_loss = 0
    if loadFilename:
        start_iteration = checkpoint['iteration'] + 1

    # Training loop
    print("Training...")
    for iteration in range(start_iteration, n_iteration + 1):
        training_batch = training_batches[iteration - 1]
        # Extract fields from batch
        input_variable, lengths, target_variable, mask, max_target_len = training_batch

        # Run a training iteration with batch
        loss = train(input_variable, lengths, target_variable, mask, max_target_len, encoder,
                     decoder, embedding, encoder_optimizer, decoder_optimizer, batch_size, clip)
        print_loss += loss

        # Print progress
        if iteration % print_every == 0:
            print_loss_avg = print_loss / print_every
            print("Iteration: {}; Percent complete: {:.1f}%; Average loss: {:.4f}".format(iteration, iteration / n_iteration * 100, print_loss_avg))
            print_loss = 0

        # Save checkpoint
        if (iteration % save_every == 0):
            directory = os.path.join(save_dir, model_name, corpus_name, '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, hidden_size))
            if not os.path.exists(directory):
                os.makedirs(directory)
            torch.save({
                'iteration': iteration,
                'en': encoder.state_dict(),
                'de': decoder.state_dict(),
                'en_opt': encoder_optimizer.state_dict(),
                'de_opt': decoder_optimizer.state_dict(),
                'loss': loss,
                'voc_dict': voc.__dict__,
                'embedding': embedding.state_dict()
            }, os.path.join(directory, '{}_{}.tar'.format(iteration, 'checkpoint')))

### Greedy decoding

In [40]:
class GreedySearchDecoder(nn.Module):
    def __init__(self, encoder, decoder):
        super(GreedySearchDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, input_seq, input_length, max_length):
        # Forward input through encoder model
        encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)
        # Prepare encoder's final hidden layer to be first hidden input to the decoder
        decoder_hidden = encoder_hidden[:decoder.n_layers]
        # Initialize decoder input with SOS_token
        decoder_input = torch.ones(1, 1, device=device, dtype=torch.long) * SOS_token
        # Initialize tensors to append decoded words to
        all_tokens = torch.zeros([0], device=device, dtype=torch.long)
        all_scores = torch.zeros([0], device=device)
        # Iteratively decode one word token at a time
        for _ in range(max_length):
            # Forward pass through decoder
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
            # Obtain most likely word token and its softmax score
            decoder_scores, decoder_input = torch.max(decoder_output, dim=1)
            # Record token and score
            all_tokens = torch.cat((all_tokens, decoder_input), dim=0)
            all_scores = torch.cat((all_scores, decoder_scores), dim=0)
            # Prepare current token to be next decoder input (add a dimension)
            decoder_input = torch.unsqueeze(decoder_input, 0)
        # Return collections of word tokens and scores
        return all_tokens, all_scores

### Evaluate my text

In [41]:
def evaluate(encoder, decoder, searcher, voc, sentence, max_length=MAX_LENGTH):
    ### Format input sentence as a batch
    # words -> indexes
    indexes_batch = [indexesFromSentence(voc, sentence)]
    # Create lengths tensor
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    # Transpose dimensions of batch to match models' expectations
    input_batch = torch.LongTensor(indexes_batch).transpose(0, 1)
    # Use appropriate device
    input_batch = input_batch.to(device)
    lengths = lengths.to(device)
    # Decode sentence with searcher
    tokens, scores = searcher(input_batch, lengths, max_length)
    # indexes -> words
    decoded_words = [voc.index2word[token.item()] for token in tokens]
    return decoded_words


def evaluateInput(encoder, decoder, searcher, voc):
    input_sentence = ''
    while(1):
        try:
            # Get input sentence
            input_sentence = input('> ')
            # Check if it is quit case
            if input_sentence == 'q' or input_sentence == 'quit': break
            # Normalize sentence
            input_sentence = normalizeString(input_sentence)
            # Evaluate sentence
            output_words = evaluate(encoder, decoder, searcher, voc, input_sentence)
            # Format and print response sentence
            output_words[:] = [x for x in output_words if not (x == 'EOS' or x == 'PAD')]
            print('Bot:', ' '.join(output_words))

        except KeyError:
            print("Error: Encountered unknown word.")

### Run Model

In [46]:
# Configure models
model_name = 'cb_model'
attn_model = 'dot'
#attn_model = 'general'
#attn_model = 'concat'
hidden_size = 500
encoder_n_layers = 2
decoder_n_layers = 2
dropout = 0.1
batch_size = 64
corpus_name = "cornell"

# Set checkpoint to load from; set to None if starting from scratch
loadFilename = None
checkpoint_iter = 4000
#loadFilename = os.path.join(save_dir, model_name, corpus_name,
#                            '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, hidden_size),
#                            '{}_checkpoint.tar'.format(checkpoint_iter))


# Load model if a loadFilename is provided
if loadFilename:
    # If loading on same machine the model was trained on
    checkpoint = torch.load(loadFilename)
    # If loading a model trained on GPU to CPU
    #checkpoint = torch.load(loadFilename, map_location=torch.device('cpu'))
    encoder_sd = checkpoint['en']
    decoder_sd = checkpoint['de']
    encoder_optimizer_sd = checkpoint['en_opt']
    decoder_optimizer_sd = checkpoint['de_opt']
    embedding_sd = checkpoint['embedding']
    voc.__dict__ = checkpoint['voc_dict']


print('Building encoder and decoder ...')
# Initialize word embeddings
embedding = nn.Embedding(voc.num_words, hidden_size)
if loadFilename:
    embedding.load_state_dict(embedding_sd)
# Initialize encoder & decoder models
encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.num_words, decoder_n_layers, dropout)
if loadFilename:
    encoder.load_state_dict(encoder_sd)
    decoder.load_state_dict(decoder_sd)
# Use appropriate device
encoder = encoder.to(device)
decoder = decoder.to(device)
print('Models built and ready to go!')

Building encoder and decoder ...
Models built and ready to go!


### Run Training

In [None]:
# Configure training/optimization
clip = 50.0
teacher_forcing_ratio = 1.0
learning_rate = 0.0001
decoder_learning_ratio = 5.0
n_iteration = 4000
print_every = 1
save_every = 500

# Ensure dropout layers are in train mode
encoder.train()
decoder.train()

# Initialize optimizers
print('Building optimizers ...')
encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate * decoder_learning_ratio)
if loadFilename:
    encoder_optimizer.load_state_dict(encoder_optimizer_sd)
    decoder_optimizer.load_state_dict(decoder_optimizer_sd)

# If you have cuda, configure cuda to call
for state in encoder_optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.cuda()

for state in decoder_optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.cuda()

# Run training iterations
print("Starting Training!")
trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer,
           embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size,
           print_every, save_every, clip, corpus_name, loadFilename)

Building optimizers ...
Starting Training!
Initializing ...
Training...
Iteration: 1; Percent complete: 0.0%; Average loss: 8.9830
Iteration: 2; Percent complete: 0.1%; Average loss: 8.7856
Iteration: 3; Percent complete: 0.1%; Average loss: 8.4752
Iteration: 4; Percent complete: 0.1%; Average loss: 8.0483
Iteration: 5; Percent complete: 0.1%; Average loss: 7.3073
Iteration: 6; Percent complete: 0.1%; Average loss: 6.9690
Iteration: 7; Percent complete: 0.2%; Average loss: 6.4934
Iteration: 8; Percent complete: 0.2%; Average loss: 6.7050
Iteration: 9; Percent complete: 0.2%; Average loss: 6.9038
Iteration: 10; Percent complete: 0.2%; Average loss: 6.2726
Iteration: 11; Percent complete: 0.3%; Average loss: 5.9458
Iteration: 12; Percent complete: 0.3%; Average loss: 5.6011
Iteration: 13; Percent complete: 0.3%; Average loss: 5.3949
Iteration: 14; Percent complete: 0.4%; Average loss: 5.3023
Iteration: 15; Percent complete: 0.4%; Average loss: 4.8455
Iteration: 16; Percent complete: 0.4%

Iteration: 136; Percent complete: 3.4%; Average loss: 3.7978
Iteration: 137; Percent complete: 3.4%; Average loss: 3.5865
Iteration: 138; Percent complete: 3.5%; Average loss: 3.7461
Iteration: 139; Percent complete: 3.5%; Average loss: 3.6112
Iteration: 140; Percent complete: 3.5%; Average loss: 3.6688
Iteration: 141; Percent complete: 3.5%; Average loss: 3.6613
Iteration: 142; Percent complete: 3.5%; Average loss: 3.6764
Iteration: 143; Percent complete: 3.6%; Average loss: 3.6118
Iteration: 144; Percent complete: 3.6%; Average loss: 3.5594
Iteration: 145; Percent complete: 3.6%; Average loss: 3.8060
Iteration: 146; Percent complete: 3.6%; Average loss: 3.6286
Iteration: 147; Percent complete: 3.7%; Average loss: 3.7421
Iteration: 148; Percent complete: 3.7%; Average loss: 3.5956
Iteration: 149; Percent complete: 3.7%; Average loss: 3.6693
Iteration: 150; Percent complete: 3.8%; Average loss: 3.3716
Iteration: 151; Percent complete: 3.8%; Average loss: 3.3549
Iteration: 152; Percent 

Iteration: 271; Percent complete: 6.8%; Average loss: 3.4812
Iteration: 272; Percent complete: 6.8%; Average loss: 3.5721
Iteration: 273; Percent complete: 6.8%; Average loss: 3.1270
Iteration: 274; Percent complete: 6.9%; Average loss: 3.2822
Iteration: 275; Percent complete: 6.9%; Average loss: 3.4033
Iteration: 276; Percent complete: 6.9%; Average loss: 3.1364
Iteration: 277; Percent complete: 6.9%; Average loss: 3.3120
Iteration: 278; Percent complete: 7.0%; Average loss: 3.6723
Iteration: 279; Percent complete: 7.0%; Average loss: 3.5805
Iteration: 280; Percent complete: 7.0%; Average loss: 3.4069
Iteration: 281; Percent complete: 7.0%; Average loss: 3.4618
Iteration: 282; Percent complete: 7.0%; Average loss: 3.7721
Iteration: 283; Percent complete: 7.1%; Average loss: 3.1536
Iteration: 284; Percent complete: 7.1%; Average loss: 3.2484
Iteration: 285; Percent complete: 7.1%; Average loss: 3.3624
Iteration: 286; Percent complete: 7.1%; Average loss: 3.5663
Iteration: 287; Percent 

Iteration: 406; Percent complete: 10.2%; Average loss: 3.2215
Iteration: 407; Percent complete: 10.2%; Average loss: 3.5606
Iteration: 408; Percent complete: 10.2%; Average loss: 3.3740
Iteration: 409; Percent complete: 10.2%; Average loss: 3.1434
Iteration: 410; Percent complete: 10.2%; Average loss: 3.2668
Iteration: 411; Percent complete: 10.3%; Average loss: 3.1630
Iteration: 412; Percent complete: 10.3%; Average loss: 3.3129
Iteration: 413; Percent complete: 10.3%; Average loss: 3.1830
Iteration: 414; Percent complete: 10.3%; Average loss: 3.3524
Iteration: 415; Percent complete: 10.4%; Average loss: 3.2978
Iteration: 416; Percent complete: 10.4%; Average loss: 3.4722
Iteration: 417; Percent complete: 10.4%; Average loss: 3.0961
Iteration: 418; Percent complete: 10.4%; Average loss: 3.0915
Iteration: 419; Percent complete: 10.5%; Average loss: 3.2775
Iteration: 420; Percent complete: 10.5%; Average loss: 3.2659
Iteration: 421; Percent complete: 10.5%; Average loss: 2.8848
Iteratio

Iteration: 539; Percent complete: 13.5%; Average loss: 3.0934
Iteration: 540; Percent complete: 13.5%; Average loss: 3.3663
Iteration: 541; Percent complete: 13.5%; Average loss: 3.0625
Iteration: 542; Percent complete: 13.6%; Average loss: 3.0323
Iteration: 543; Percent complete: 13.6%; Average loss: 3.3513
Iteration: 544; Percent complete: 13.6%; Average loss: 3.2060
Iteration: 545; Percent complete: 13.6%; Average loss: 3.0634
Iteration: 546; Percent complete: 13.7%; Average loss: 3.2004
Iteration: 547; Percent complete: 13.7%; Average loss: 2.8353
Iteration: 548; Percent complete: 13.7%; Average loss: 2.9417
Iteration: 549; Percent complete: 13.7%; Average loss: 3.1364
Iteration: 550; Percent complete: 13.8%; Average loss: 3.1598
Iteration: 551; Percent complete: 13.8%; Average loss: 3.1701
Iteration: 552; Percent complete: 13.8%; Average loss: 3.1980
Iteration: 553; Percent complete: 13.8%; Average loss: 3.0688
Iteration: 554; Percent complete: 13.9%; Average loss: 2.9826
Iteratio

Iteration: 672; Percent complete: 16.8%; Average loss: 2.8832
Iteration: 673; Percent complete: 16.8%; Average loss: 3.4422
Iteration: 674; Percent complete: 16.9%; Average loss: 3.2320
Iteration: 675; Percent complete: 16.9%; Average loss: 3.2185
Iteration: 676; Percent complete: 16.9%; Average loss: 3.0275
Iteration: 677; Percent complete: 16.9%; Average loss: 3.1088
Iteration: 678; Percent complete: 17.0%; Average loss: 3.0993
Iteration: 679; Percent complete: 17.0%; Average loss: 2.8415
Iteration: 680; Percent complete: 17.0%; Average loss: 2.8311
Iteration: 681; Percent complete: 17.0%; Average loss: 3.0678
Iteration: 682; Percent complete: 17.1%; Average loss: 3.0900
Iteration: 683; Percent complete: 17.1%; Average loss: 3.1207
Iteration: 684; Percent complete: 17.1%; Average loss: 3.0140
Iteration: 685; Percent complete: 17.1%; Average loss: 3.0981
Iteration: 686; Percent complete: 17.2%; Average loss: 3.2042
Iteration: 687; Percent complete: 17.2%; Average loss: 3.1716
Iteratio

Iteration: 805; Percent complete: 20.1%; Average loss: 2.9859
Iteration: 806; Percent complete: 20.2%; Average loss: 3.0059
Iteration: 807; Percent complete: 20.2%; Average loss: 3.0400
Iteration: 808; Percent complete: 20.2%; Average loss: 2.7172
Iteration: 809; Percent complete: 20.2%; Average loss: 2.9538
Iteration: 810; Percent complete: 20.2%; Average loss: 3.2464
Iteration: 811; Percent complete: 20.3%; Average loss: 3.1493
Iteration: 812; Percent complete: 20.3%; Average loss: 2.8934
Iteration: 813; Percent complete: 20.3%; Average loss: 3.0580
Iteration: 814; Percent complete: 20.3%; Average loss: 3.1638
Iteration: 815; Percent complete: 20.4%; Average loss: 2.9311
Iteration: 816; Percent complete: 20.4%; Average loss: 2.9518
Iteration: 817; Percent complete: 20.4%; Average loss: 3.1416
Iteration: 818; Percent complete: 20.4%; Average loss: 2.8547
Iteration: 819; Percent complete: 20.5%; Average loss: 2.8486
Iteration: 820; Percent complete: 20.5%; Average loss: 2.9516
Iteratio

Iteration: 938; Percent complete: 23.4%; Average loss: 3.0095
Iteration: 939; Percent complete: 23.5%; Average loss: 3.1120
Iteration: 940; Percent complete: 23.5%; Average loss: 3.0041
Iteration: 941; Percent complete: 23.5%; Average loss: 2.9600
Iteration: 942; Percent complete: 23.5%; Average loss: 2.9218
Iteration: 943; Percent complete: 23.6%; Average loss: 2.7048
Iteration: 944; Percent complete: 23.6%; Average loss: 3.2376
Iteration: 945; Percent complete: 23.6%; Average loss: 2.8471
Iteration: 946; Percent complete: 23.6%; Average loss: 3.2117
Iteration: 947; Percent complete: 23.7%; Average loss: 3.2057
Iteration: 948; Percent complete: 23.7%; Average loss: 2.8874
Iteration: 949; Percent complete: 23.7%; Average loss: 2.9590
Iteration: 950; Percent complete: 23.8%; Average loss: 2.9995
Iteration: 951; Percent complete: 23.8%; Average loss: 2.9586
Iteration: 952; Percent complete: 23.8%; Average loss: 3.0019
Iteration: 953; Percent complete: 23.8%; Average loss: 2.9283
Iteratio

Iteration: 1070; Percent complete: 26.8%; Average loss: 2.8193
Iteration: 1071; Percent complete: 26.8%; Average loss: 2.7372
Iteration: 1072; Percent complete: 26.8%; Average loss: 2.6875
Iteration: 1073; Percent complete: 26.8%; Average loss: 3.1191
Iteration: 1074; Percent complete: 26.9%; Average loss: 2.8429
Iteration: 1075; Percent complete: 26.9%; Average loss: 2.7493
Iteration: 1076; Percent complete: 26.9%; Average loss: 2.9189
Iteration: 1077; Percent complete: 26.9%; Average loss: 2.9475
Iteration: 1078; Percent complete: 27.0%; Average loss: 2.9192
Iteration: 1079; Percent complete: 27.0%; Average loss: 3.0471
Iteration: 1080; Percent complete: 27.0%; Average loss: 3.0316
Iteration: 1081; Percent complete: 27.0%; Average loss: 3.0307
Iteration: 1082; Percent complete: 27.1%; Average loss: 2.8967
Iteration: 1083; Percent complete: 27.1%; Average loss: 2.7292
Iteration: 1084; Percent complete: 27.1%; Average loss: 2.7126
Iteration: 1085; Percent complete: 27.1%; Average loss:

Iteration: 1201; Percent complete: 30.0%; Average loss: 2.6839
Iteration: 1202; Percent complete: 30.0%; Average loss: 2.7714
Iteration: 1203; Percent complete: 30.1%; Average loss: 2.6263
Iteration: 1204; Percent complete: 30.1%; Average loss: 2.9493
Iteration: 1205; Percent complete: 30.1%; Average loss: 3.1816
Iteration: 1206; Percent complete: 30.1%; Average loss: 2.8169
Iteration: 1207; Percent complete: 30.2%; Average loss: 2.9501
Iteration: 1208; Percent complete: 30.2%; Average loss: 3.0129
Iteration: 1209; Percent complete: 30.2%; Average loss: 2.8102
Iteration: 1210; Percent complete: 30.2%; Average loss: 2.7862
Iteration: 1211; Percent complete: 30.3%; Average loss: 2.7214
Iteration: 1212; Percent complete: 30.3%; Average loss: 3.0298
Iteration: 1213; Percent complete: 30.3%; Average loss: 2.9916
Iteration: 1214; Percent complete: 30.3%; Average loss: 2.9631
Iteration: 1215; Percent complete: 30.4%; Average loss: 2.8556
Iteration: 1216; Percent complete: 30.4%; Average loss:

Iteration: 1332; Percent complete: 33.3%; Average loss: 2.7823
Iteration: 1333; Percent complete: 33.3%; Average loss: 2.9287
Iteration: 1334; Percent complete: 33.4%; Average loss: 2.8414
Iteration: 1335; Percent complete: 33.4%; Average loss: 2.6762
Iteration: 1336; Percent complete: 33.4%; Average loss: 2.8250
Iteration: 1337; Percent complete: 33.4%; Average loss: 2.9083
Iteration: 1338; Percent complete: 33.5%; Average loss: 2.4933
Iteration: 1339; Percent complete: 33.5%; Average loss: 3.0799
Iteration: 1340; Percent complete: 33.5%; Average loss: 2.9316
Iteration: 1341; Percent complete: 33.5%; Average loss: 2.8729
Iteration: 1342; Percent complete: 33.6%; Average loss: 2.8779
Iteration: 1343; Percent complete: 33.6%; Average loss: 2.6607
Iteration: 1344; Percent complete: 33.6%; Average loss: 2.5529
Iteration: 1345; Percent complete: 33.6%; Average loss: 2.7181
Iteration: 1346; Percent complete: 33.7%; Average loss: 3.0142
Iteration: 1347; Percent complete: 33.7%; Average loss:

Iteration: 1463; Percent complete: 36.6%; Average loss: 2.7108
Iteration: 1464; Percent complete: 36.6%; Average loss: 2.8435
Iteration: 1465; Percent complete: 36.6%; Average loss: 2.8071
Iteration: 1466; Percent complete: 36.6%; Average loss: 2.6816
Iteration: 1467; Percent complete: 36.7%; Average loss: 2.8195
Iteration: 1468; Percent complete: 36.7%; Average loss: 2.7521
Iteration: 1469; Percent complete: 36.7%; Average loss: 2.4296
Iteration: 1470; Percent complete: 36.8%; Average loss: 2.8549
Iteration: 1471; Percent complete: 36.8%; Average loss: 2.5942
Iteration: 1472; Percent complete: 36.8%; Average loss: 2.6956
Iteration: 1473; Percent complete: 36.8%; Average loss: 2.8331
Iteration: 1474; Percent complete: 36.9%; Average loss: 2.8033
Iteration: 1475; Percent complete: 36.9%; Average loss: 2.9796
Iteration: 1476; Percent complete: 36.9%; Average loss: 2.5894
Iteration: 1477; Percent complete: 36.9%; Average loss: 2.8504
Iteration: 1478; Percent complete: 37.0%; Average loss:

Iteration: 1594; Percent complete: 39.9%; Average loss: 2.6692
Iteration: 1595; Percent complete: 39.9%; Average loss: 2.8437
Iteration: 1596; Percent complete: 39.9%; Average loss: 2.5702
Iteration: 1597; Percent complete: 39.9%; Average loss: 2.6933
Iteration: 1598; Percent complete: 40.0%; Average loss: 2.8156
Iteration: 1599; Percent complete: 40.0%; Average loss: 2.8533
Iteration: 1600; Percent complete: 40.0%; Average loss: 2.8416
Iteration: 1601; Percent complete: 40.0%; Average loss: 2.7695
Iteration: 1602; Percent complete: 40.1%; Average loss: 2.6699
Iteration: 1603; Percent complete: 40.1%; Average loss: 3.0486
Iteration: 1604; Percent complete: 40.1%; Average loss: 2.6964
Iteration: 1605; Percent complete: 40.1%; Average loss: 2.9706
Iteration: 1606; Percent complete: 40.2%; Average loss: 2.4853
Iteration: 1607; Percent complete: 40.2%; Average loss: 2.8366
Iteration: 1608; Percent complete: 40.2%; Average loss: 2.7374
Iteration: 1609; Percent complete: 40.2%; Average loss:

Iteration: 1725; Percent complete: 43.1%; Average loss: 2.7153
Iteration: 1726; Percent complete: 43.1%; Average loss: 2.6977
Iteration: 1727; Percent complete: 43.2%; Average loss: 2.7292
Iteration: 1728; Percent complete: 43.2%; Average loss: 2.9522
Iteration: 1729; Percent complete: 43.2%; Average loss: 2.4213
Iteration: 1730; Percent complete: 43.2%; Average loss: 2.5678
Iteration: 1731; Percent complete: 43.3%; Average loss: 2.5969
Iteration: 1732; Percent complete: 43.3%; Average loss: 2.6242
Iteration: 1733; Percent complete: 43.3%; Average loss: 2.8106
Iteration: 1734; Percent complete: 43.4%; Average loss: 2.7520
Iteration: 1735; Percent complete: 43.4%; Average loss: 2.7411
Iteration: 1736; Percent complete: 43.4%; Average loss: 2.6683
Iteration: 1737; Percent complete: 43.4%; Average loss: 2.7442
Iteration: 1738; Percent complete: 43.5%; Average loss: 2.5393
Iteration: 1739; Percent complete: 43.5%; Average loss: 2.8060
Iteration: 1740; Percent complete: 43.5%; Average loss:

Iteration: 1856; Percent complete: 46.4%; Average loss: 2.6162
Iteration: 1857; Percent complete: 46.4%; Average loss: 2.6968
Iteration: 1858; Percent complete: 46.5%; Average loss: 2.5940
Iteration: 1859; Percent complete: 46.5%; Average loss: 2.7615
Iteration: 1860; Percent complete: 46.5%; Average loss: 2.7712
Iteration: 1861; Percent complete: 46.5%; Average loss: 2.7525
Iteration: 1862; Percent complete: 46.6%; Average loss: 2.8722
Iteration: 1863; Percent complete: 46.6%; Average loss: 2.7514
Iteration: 1864; Percent complete: 46.6%; Average loss: 2.6262
Iteration: 1865; Percent complete: 46.6%; Average loss: 2.8592
Iteration: 1866; Percent complete: 46.7%; Average loss: 2.7996
Iteration: 1867; Percent complete: 46.7%; Average loss: 2.6541
Iteration: 1868; Percent complete: 46.7%; Average loss: 2.7694
Iteration: 1869; Percent complete: 46.7%; Average loss: 2.5937
Iteration: 1870; Percent complete: 46.8%; Average loss: 2.4384
Iteration: 1871; Percent complete: 46.8%; Average loss:

Iteration: 1987; Percent complete: 49.7%; Average loss: 2.6554
Iteration: 1988; Percent complete: 49.7%; Average loss: 2.8510
Iteration: 1989; Percent complete: 49.7%; Average loss: 2.5534
Iteration: 1990; Percent complete: 49.8%; Average loss: 2.6883
Iteration: 1991; Percent complete: 49.8%; Average loss: 2.6679
Iteration: 1992; Percent complete: 49.8%; Average loss: 2.5181
Iteration: 1993; Percent complete: 49.8%; Average loss: 2.8272
Iteration: 1994; Percent complete: 49.9%; Average loss: 2.7743
Iteration: 1995; Percent complete: 49.9%; Average loss: 2.5298
Iteration: 1996; Percent complete: 49.9%; Average loss: 2.4962
Iteration: 1997; Percent complete: 49.9%; Average loss: 2.5674
Iteration: 1998; Percent complete: 50.0%; Average loss: 2.6998
Iteration: 1999; Percent complete: 50.0%; Average loss: 2.6612
Iteration: 2000; Percent complete: 50.0%; Average loss: 2.8095
Iteration: 2001; Percent complete: 50.0%; Average loss: 2.4322
Iteration: 2002; Percent complete: 50.0%; Average loss:

Iteration: 2118; Percent complete: 52.9%; Average loss: 2.6525
Iteration: 2119; Percent complete: 53.0%; Average loss: 2.6659
Iteration: 2120; Percent complete: 53.0%; Average loss: 2.9502
Iteration: 2121; Percent complete: 53.0%; Average loss: 2.7734
Iteration: 2122; Percent complete: 53.0%; Average loss: 2.6114
Iteration: 2123; Percent complete: 53.1%; Average loss: 2.2445
Iteration: 2124; Percent complete: 53.1%; Average loss: 2.7025
Iteration: 2125; Percent complete: 53.1%; Average loss: 2.5693
Iteration: 2126; Percent complete: 53.1%; Average loss: 2.7982
Iteration: 2127; Percent complete: 53.2%; Average loss: 2.6026
Iteration: 2128; Percent complete: 53.2%; Average loss: 3.0508
Iteration: 2129; Percent complete: 53.2%; Average loss: 2.4755
Iteration: 2130; Percent complete: 53.2%; Average loss: 2.4493
Iteration: 2131; Percent complete: 53.3%; Average loss: 2.8556
Iteration: 2132; Percent complete: 53.3%; Average loss: 2.7814
Iteration: 2133; Percent complete: 53.3%; Average loss:

Iteration: 2249; Percent complete: 56.2%; Average loss: 2.4526
Iteration: 2250; Percent complete: 56.2%; Average loss: 2.5469
Iteration: 2251; Percent complete: 56.3%; Average loss: 2.4924
Iteration: 2252; Percent complete: 56.3%; Average loss: 2.5940
Iteration: 2253; Percent complete: 56.3%; Average loss: 2.5232
Iteration: 2254; Percent complete: 56.4%; Average loss: 2.6529
Iteration: 2255; Percent complete: 56.4%; Average loss: 2.6590
Iteration: 2256; Percent complete: 56.4%; Average loss: 2.6044
Iteration: 2257; Percent complete: 56.4%; Average loss: 2.3409
Iteration: 2258; Percent complete: 56.5%; Average loss: 2.3783
Iteration: 2259; Percent complete: 56.5%; Average loss: 2.7709
Iteration: 2260; Percent complete: 56.5%; Average loss: 2.6647
Iteration: 2261; Percent complete: 56.5%; Average loss: 2.5562
Iteration: 2262; Percent complete: 56.5%; Average loss: 2.3115
Iteration: 2263; Percent complete: 56.6%; Average loss: 2.7392
Iteration: 2264; Percent complete: 56.6%; Average loss:

Iteration: 2380; Percent complete: 59.5%; Average loss: 2.6097
Iteration: 2381; Percent complete: 59.5%; Average loss: 2.8038
Iteration: 2382; Percent complete: 59.6%; Average loss: 2.4350
Iteration: 2383; Percent complete: 59.6%; Average loss: 2.6563
Iteration: 2384; Percent complete: 59.6%; Average loss: 2.3945
Iteration: 2385; Percent complete: 59.6%; Average loss: 2.7651
Iteration: 2386; Percent complete: 59.7%; Average loss: 2.4249
Iteration: 2387; Percent complete: 59.7%; Average loss: 2.6778
Iteration: 2388; Percent complete: 59.7%; Average loss: 2.8019
Iteration: 2389; Percent complete: 59.7%; Average loss: 2.6481
Iteration: 2390; Percent complete: 59.8%; Average loss: 2.4201
Iteration: 2391; Percent complete: 59.8%; Average loss: 2.2632
Iteration: 2392; Percent complete: 59.8%; Average loss: 2.5472
Iteration: 2393; Percent complete: 59.8%; Average loss: 2.4172
Iteration: 2394; Percent complete: 59.9%; Average loss: 2.4125
Iteration: 2395; Percent complete: 59.9%; Average loss:

Iteration: 2511; Percent complete: 62.8%; Average loss: 2.3790
Iteration: 2512; Percent complete: 62.8%; Average loss: 2.6804
Iteration: 2513; Percent complete: 62.8%; Average loss: 2.3439
Iteration: 2514; Percent complete: 62.8%; Average loss: 2.5561
Iteration: 2515; Percent complete: 62.9%; Average loss: 2.5578
Iteration: 2516; Percent complete: 62.9%; Average loss: 2.6843
Iteration: 2517; Percent complete: 62.9%; Average loss: 2.6263
Iteration: 2518; Percent complete: 62.9%; Average loss: 2.6613
Iteration: 2519; Percent complete: 63.0%; Average loss: 2.4795
Iteration: 2520; Percent complete: 63.0%; Average loss: 2.7570
Iteration: 2521; Percent complete: 63.0%; Average loss: 2.4326
Iteration: 2522; Percent complete: 63.0%; Average loss: 2.3460
Iteration: 2523; Percent complete: 63.1%; Average loss: 2.5433
Iteration: 2524; Percent complete: 63.1%; Average loss: 2.3677
Iteration: 2525; Percent complete: 63.1%; Average loss: 2.2889
Iteration: 2526; Percent complete: 63.1%; Average loss:

Iteration: 2642; Percent complete: 66.0%; Average loss: 2.4772
Iteration: 2643; Percent complete: 66.1%; Average loss: 2.4675
Iteration: 2644; Percent complete: 66.1%; Average loss: 2.8489
Iteration: 2645; Percent complete: 66.1%; Average loss: 2.3673
Iteration: 2646; Percent complete: 66.1%; Average loss: 2.6228
Iteration: 2647; Percent complete: 66.2%; Average loss: 2.4432
Iteration: 2648; Percent complete: 66.2%; Average loss: 2.4545
Iteration: 2649; Percent complete: 66.2%; Average loss: 2.5545
Iteration: 2650; Percent complete: 66.2%; Average loss: 2.6350
Iteration: 2651; Percent complete: 66.3%; Average loss: 2.6295
Iteration: 2652; Percent complete: 66.3%; Average loss: 2.5423
Iteration: 2653; Percent complete: 66.3%; Average loss: 2.4973
Iteration: 2654; Percent complete: 66.3%; Average loss: 2.6212
Iteration: 2655; Percent complete: 66.4%; Average loss: 2.4017
Iteration: 2656; Percent complete: 66.4%; Average loss: 2.4502
Iteration: 2657; Percent complete: 66.4%; Average loss:

Iteration: 2773; Percent complete: 69.3%; Average loss: 2.5889
Iteration: 2774; Percent complete: 69.3%; Average loss: 2.7482
Iteration: 2775; Percent complete: 69.4%; Average loss: 2.3477
Iteration: 2776; Percent complete: 69.4%; Average loss: 2.4093
Iteration: 2777; Percent complete: 69.4%; Average loss: 2.5067
Iteration: 2778; Percent complete: 69.5%; Average loss: 2.4046
Iteration: 2779; Percent complete: 69.5%; Average loss: 2.5883
Iteration: 2780; Percent complete: 69.5%; Average loss: 2.6282
Iteration: 2781; Percent complete: 69.5%; Average loss: 2.4800
Iteration: 2782; Percent complete: 69.5%; Average loss: 2.3876
Iteration: 2783; Percent complete: 69.6%; Average loss: 2.2298
Iteration: 2784; Percent complete: 69.6%; Average loss: 2.6805
Iteration: 2785; Percent complete: 69.6%; Average loss: 2.6548
Iteration: 2786; Percent complete: 69.7%; Average loss: 2.3931
Iteration: 2787; Percent complete: 69.7%; Average loss: 2.5649
Iteration: 2788; Percent complete: 69.7%; Average loss:

Iteration: 2904; Percent complete: 72.6%; Average loss: 2.4309
Iteration: 2905; Percent complete: 72.6%; Average loss: 2.2438
Iteration: 2906; Percent complete: 72.7%; Average loss: 2.4373
Iteration: 2907; Percent complete: 72.7%; Average loss: 2.4169
Iteration: 2908; Percent complete: 72.7%; Average loss: 2.6040
Iteration: 2909; Percent complete: 72.7%; Average loss: 2.2680
Iteration: 2910; Percent complete: 72.8%; Average loss: 2.4107
Iteration: 2911; Percent complete: 72.8%; Average loss: 2.4986
Iteration: 2912; Percent complete: 72.8%; Average loss: 2.5688
Iteration: 2913; Percent complete: 72.8%; Average loss: 2.4629
Iteration: 2914; Percent complete: 72.9%; Average loss: 2.3687
Iteration: 2915; Percent complete: 72.9%; Average loss: 2.3434
Iteration: 2916; Percent complete: 72.9%; Average loss: 2.3934
Iteration: 2917; Percent complete: 72.9%; Average loss: 2.5586
Iteration: 2918; Percent complete: 73.0%; Average loss: 2.5086
Iteration: 2919; Percent complete: 73.0%; Average loss:

Iteration: 3035; Percent complete: 75.9%; Average loss: 2.4199
Iteration: 3036; Percent complete: 75.9%; Average loss: 2.3489
Iteration: 3037; Percent complete: 75.9%; Average loss: 2.3792
Iteration: 3038; Percent complete: 75.9%; Average loss: 2.4149
Iteration: 3039; Percent complete: 76.0%; Average loss: 2.3638
Iteration: 3040; Percent complete: 76.0%; Average loss: 2.1507
Iteration: 3041; Percent complete: 76.0%; Average loss: 2.2132
Iteration: 3042; Percent complete: 76.0%; Average loss: 2.0982
Iteration: 3043; Percent complete: 76.1%; Average loss: 2.4528
Iteration: 3044; Percent complete: 76.1%; Average loss: 2.4310
Iteration: 3045; Percent complete: 76.1%; Average loss: 2.3357
Iteration: 3046; Percent complete: 76.1%; Average loss: 2.4454
Iteration: 3047; Percent complete: 76.2%; Average loss: 2.2791
Iteration: 3048; Percent complete: 76.2%; Average loss: 2.4212
Iteration: 3049; Percent complete: 76.2%; Average loss: 2.5612
Iteration: 3050; Percent complete: 76.2%; Average loss:

Iteration: 3166; Percent complete: 79.1%; Average loss: 2.1797
Iteration: 3167; Percent complete: 79.2%; Average loss: 2.4459
Iteration: 3168; Percent complete: 79.2%; Average loss: 2.4533
Iteration: 3169; Percent complete: 79.2%; Average loss: 2.4544
Iteration: 3170; Percent complete: 79.2%; Average loss: 2.1950
Iteration: 3171; Percent complete: 79.3%; Average loss: 2.2300
Iteration: 3172; Percent complete: 79.3%; Average loss: 2.5782
Iteration: 3173; Percent complete: 79.3%; Average loss: 2.2447
Iteration: 3174; Percent complete: 79.3%; Average loss: 2.2438
Iteration: 3175; Percent complete: 79.4%; Average loss: 2.3665
Iteration: 3176; Percent complete: 79.4%; Average loss: 2.2132
Iteration: 3177; Percent complete: 79.4%; Average loss: 2.2745
Iteration: 3178; Percent complete: 79.5%; Average loss: 2.3456
Iteration: 3179; Percent complete: 79.5%; Average loss: 2.3355
Iteration: 3180; Percent complete: 79.5%; Average loss: 2.1911
Iteration: 3181; Percent complete: 79.5%; Average loss:

Iteration: 3297; Percent complete: 82.4%; Average loss: 2.2802
Iteration: 3298; Percent complete: 82.5%; Average loss: 2.3608
Iteration: 3299; Percent complete: 82.5%; Average loss: 2.4406
Iteration: 3300; Percent complete: 82.5%; Average loss: 2.2747
Iteration: 3301; Percent complete: 82.5%; Average loss: 2.4705
Iteration: 3302; Percent complete: 82.5%; Average loss: 2.1935
Iteration: 3303; Percent complete: 82.6%; Average loss: 2.3963
Iteration: 3304; Percent complete: 82.6%; Average loss: 2.2258
Iteration: 3305; Percent complete: 82.6%; Average loss: 2.4961
Iteration: 3306; Percent complete: 82.7%; Average loss: 2.1771
Iteration: 3307; Percent complete: 82.7%; Average loss: 2.3420
Iteration: 3308; Percent complete: 82.7%; Average loss: 2.4218
Iteration: 3309; Percent complete: 82.7%; Average loss: 2.3178
Iteration: 3310; Percent complete: 82.8%; Average loss: 2.2296
Iteration: 3311; Percent complete: 82.8%; Average loss: 2.5865
Iteration: 3312; Percent complete: 82.8%; Average loss:

Iteration: 3428; Percent complete: 85.7%; Average loss: 2.4739
Iteration: 3429; Percent complete: 85.7%; Average loss: 2.3420
Iteration: 3430; Percent complete: 85.8%; Average loss: 2.0936
Iteration: 3431; Percent complete: 85.8%; Average loss: 2.1772
Iteration: 3432; Percent complete: 85.8%; Average loss: 2.4366
Iteration: 3433; Percent complete: 85.8%; Average loss: 2.4322
Iteration: 3434; Percent complete: 85.9%; Average loss: 2.3566
Iteration: 3435; Percent complete: 85.9%; Average loss: 2.1746
Iteration: 3436; Percent complete: 85.9%; Average loss: 2.2553
Iteration: 3437; Percent complete: 85.9%; Average loss: 2.4053
Iteration: 3438; Percent complete: 86.0%; Average loss: 2.4688
Iteration: 3439; Percent complete: 86.0%; Average loss: 2.3284
Iteration: 3440; Percent complete: 86.0%; Average loss: 2.2311
Iteration: 3441; Percent complete: 86.0%; Average loss: 2.5506
Iteration: 3442; Percent complete: 86.1%; Average loss: 2.3901
Iteration: 3443; Percent complete: 86.1%; Average loss:

Iteration: 3559; Percent complete: 89.0%; Average loss: 2.0831
Iteration: 3560; Percent complete: 89.0%; Average loss: 2.2987
Iteration: 3561; Percent complete: 89.0%; Average loss: 2.4091
Iteration: 3562; Percent complete: 89.0%; Average loss: 2.3424
Iteration: 3563; Percent complete: 89.1%; Average loss: 2.2279
Iteration: 3564; Percent complete: 89.1%; Average loss: 2.1799
Iteration: 3565; Percent complete: 89.1%; Average loss: 2.3252
Iteration: 3566; Percent complete: 89.1%; Average loss: 2.1812
Iteration: 3567; Percent complete: 89.2%; Average loss: 2.2419
Iteration: 3568; Percent complete: 89.2%; Average loss: 2.3286
Iteration: 3569; Percent complete: 89.2%; Average loss: 2.2641
Iteration: 3570; Percent complete: 89.2%; Average loss: 2.4282
Iteration: 3571; Percent complete: 89.3%; Average loss: 2.2841
Iteration: 3572; Percent complete: 89.3%; Average loss: 2.2282
Iteration: 3573; Percent complete: 89.3%; Average loss: 2.3946
Iteration: 3574; Percent complete: 89.3%; Average loss:

Iteration: 3690; Percent complete: 92.2%; Average loss: 2.2337
Iteration: 3691; Percent complete: 92.3%; Average loss: 2.3121
Iteration: 3692; Percent complete: 92.3%; Average loss: 2.2453
Iteration: 3693; Percent complete: 92.3%; Average loss: 2.1822
Iteration: 3694; Percent complete: 92.3%; Average loss: 2.2530
Iteration: 3695; Percent complete: 92.4%; Average loss: 2.2066
Iteration: 3696; Percent complete: 92.4%; Average loss: 2.3642
Iteration: 3697; Percent complete: 92.4%; Average loss: 2.1352
Iteration: 3698; Percent complete: 92.5%; Average loss: 2.2838
Iteration: 3699; Percent complete: 92.5%; Average loss: 2.0768
Iteration: 3700; Percent complete: 92.5%; Average loss: 2.3135
Iteration: 3701; Percent complete: 92.5%; Average loss: 2.3360
Iteration: 3702; Percent complete: 92.5%; Average loss: 2.2075
Iteration: 3703; Percent complete: 92.6%; Average loss: 2.1959
Iteration: 3704; Percent complete: 92.6%; Average loss: 2.0864
Iteration: 3705; Percent complete: 92.6%; Average loss:

Iteration: 3821; Percent complete: 95.5%; Average loss: 2.0760
Iteration: 3822; Percent complete: 95.5%; Average loss: 2.1479
Iteration: 3823; Percent complete: 95.6%; Average loss: 2.3408
Iteration: 3824; Percent complete: 95.6%; Average loss: 2.0688
Iteration: 3825; Percent complete: 95.6%; Average loss: 2.0953
Iteration: 3826; Percent complete: 95.7%; Average loss: 2.0244
Iteration: 3827; Percent complete: 95.7%; Average loss: 2.1621
Iteration: 3828; Percent complete: 95.7%; Average loss: 2.3596
Iteration: 3829; Percent complete: 95.7%; Average loss: 2.2562
Iteration: 3830; Percent complete: 95.8%; Average loss: 2.3052
Iteration: 3831; Percent complete: 95.8%; Average loss: 2.0401
Iteration: 3832; Percent complete: 95.8%; Average loss: 2.1372
Iteration: 3833; Percent complete: 95.8%; Average loss: 2.2753
Iteration: 3834; Percent complete: 95.9%; Average loss: 2.2086
Iteration: 3835; Percent complete: 95.9%; Average loss: 2.1820
Iteration: 3836; Percent complete: 95.9%; Average loss:

Iteration: 3952; Percent complete: 98.8%; Average loss: 2.2477
Iteration: 3953; Percent complete: 98.8%; Average loss: 2.0804
Iteration: 3954; Percent complete: 98.9%; Average loss: 2.2104
Iteration: 3955; Percent complete: 98.9%; Average loss: 2.1933
Iteration: 3956; Percent complete: 98.9%; Average loss: 2.2103
Iteration: 3957; Percent complete: 98.9%; Average loss: 2.2359
Iteration: 3958; Percent complete: 99.0%; Average loss: 2.0085
Iteration: 3959; Percent complete: 99.0%; Average loss: 2.2295
Iteration: 3960; Percent complete: 99.0%; Average loss: 2.1607
Iteration: 3961; Percent complete: 99.0%; Average loss: 2.2020
Iteration: 3962; Percent complete: 99.1%; Average loss: 2.2453
Iteration: 3963; Percent complete: 99.1%; Average loss: 1.9378
Iteration: 3964; Percent complete: 99.1%; Average loss: 2.2990
Iteration: 3965; Percent complete: 99.1%; Average loss: 1.9979
Iteration: 3966; Percent complete: 99.2%; Average loss: 2.3149
Iteration: 3967; Percent complete: 99.2%; Average loss:

In [None]:
# Set dropout layers to eval mode
encoder.eval()
decoder.eval()

# Initialize search module
searcher = GreedySearchDecoder(encoder, decoder)

# Begin chatting (uncomment and run the following line to begin)
# evaluateInput(encoder, decoder, searcher, voc)