## Preparations

In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import torch
from torch.jit import script, trace
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import csv
import random
import re
import os
import unicodedata
import codecs
from io import open
import itertools
import math


USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")

## Load & Preprocess Data 

In [32]:
corpus_name = "dailydialog"
corpus = os.path.join("data", corpus_name)

def printLines(file, n=10):
    """Print the first ``n`` lines of a UTF-8 text file.

    Each line is printed with its own trailing newline intact, so the
    output shows a blank line between entries.

    Args:
        file: Path of the file to read.
        n: Number of leading lines to print. ``None`` prints every line;
            a negative value keeps list-slice semantics (all but the
            last ``|n|`` lines).
    """
    with open(file, 'r', encoding="utf-8") as datafile:
        if n is not None and n < 0:
            # Negative counts need the total length, so the whole file must be read.
            head = datafile.readlines()[:n]
        else:
            # islice stops after n lines, so a large corpus is never fully loaded.
            head = list(itertools.islice(datafile, n))
    for line in head:
        print(line)

printLines(os.path.join(corpus, "dialogues_text.txt"))

The kitchen stinks . __eou__ I'll throw out the garbage . __eou__

So Dick , how about getting some coffee for tonight ? __eou__ Coffee ? I don ’ t honestly like that kind of stuff . __eou__ Come on , you can at least try a little , besides your cigarette . __eou__ What ’ s wrong with that ? Cigarette is the thing I go crazy for . __eou__ Not for me , Dick . __eou__

Are things still going badly with your houseguest ? __eou__ Getting worse . Now he ’ s eating me out of house and home . I ’ Ve tried talking to him but it all goes in one ear and out the other . He makes himself at home , which is fine . But what really gets me is that yesterday he walked into the living room in the raw and I had company over ! That was the last straw . __eou__ Leo , I really think you ’ re beating around the bush with this guy . I know he used to be your best friend in college , but I really think it ’ s time to lay down the law . __eou__ You ’ re right . Everything is probably going to come to a head to

## Create formatted data file

In [40]:
# Splits each line of the file into a list of utterances
def loadLines(fileName):
    """Read a dialogue file and split every line on the ``__eou__`` marker.

    Args:
        fileName: Path of a UTF-8 text file with one conversation per line.

    Returns:
        A list with one entry per line; each entry is the list of pieces
        produced by ``line.split("__eou__")``. Pieces are not stripped, so
        the last piece holds whatever follows the final marker.
    """
    with open(fileName, 'r', encoding='utf-8') as f:
        return [line.split("__eou__") for line in f]


# Extracts pairs of sentences from conversations
def extractSentencePairs(conversations):
    """Build [input, target] pairs from consecutive utterances.

    Args:
        conversations: Iterable of conversations, each a list of utterance
            strings.

    Returns:
        A list of two-element lists ``[inputLine, targetLine]`` holding the
        stripped text of every adjacent pair of utterances in which both
        sides are non-empty after stripping.
    """
    qa_pairs = []
    for conversation in conversations:
        # Pair each utterance with its successor; the last one has no reply.
        for current, following in zip(conversation, conversation[1:]):
            inputLine = current.strip()
            targetLine = following.strip()
            # Drop samples where either side is empty
            if not (inputLine and targetLine):
                continue
            qa_pairs.append([inputLine, targetLine])
    return qa_pairs

In [41]:
# Build the tab-separated file of (input, target) utterance pairs.

# Define path to new file
datafile = os.path.join(corpus, "formatted_dialogues_text.txt")

delimiter = '\t'
# Unescape the delimiter
# NOTE(review): the literal above is already a tab character, so this decode
# presumably leaves it unchanged - confirm before removing.
delimiter = str(codecs.decode(delimiter, "unicode_escape"))

print("\nLoading conversations...")
conversations = loadLines(os.path.join(corpus, "dialogues_text.txt"))

# Write new csv file
# One row per pair: the input utterance and the target utterance, separated by `delimiter`.
print("\nWriting newly formatted file...")
with open(datafile, 'w', encoding='utf-8') as outputfile:
    writer = csv.writer(outputfile, delimiter=delimiter, lineterminator='\n')
    for pair in extractSentencePairs(conversations):
        writer.writerow(pair)

# Print a sample of lines
print("\nSample lines from file:")
printLines(datafile)


Loading conversations...

Writing newly formatted file...

Sample lines from file:
The kitchen stinks .	I'll throw out the garbage .

So Dick , how about getting some coffee for tonight ?	Coffee ? I don ’ t honestly like that kind of stuff .

Coffee ? I don ’ t honestly like that kind of stuff .	Come on , you can at least try a little , besides your cigarette .

Come on , you can at least try a little , besides your cigarette .	What ’ s wrong with that ? Cigarette is the thing I go crazy for .

What ’ s wrong with that ? Cigarette is the thing I go crazy for .	Not for me , Dick .

Are things still going badly with your houseguest ?	Getting worse . Now he ’ s eating me out of house and home . I ’ Ve tried talking to him but it all goes in one ear and out the other . He makes himself at home , which is fine . But what really gets me is that yesterday he walked into the living room in the raw and I had company over ! That was the last straw .

Getting worse . Now he ’ s eating me out of 

## Load and trim data

In [42]:
# Default word tokens
PAD_token = 0  # Used for padding short sentences
SOS_token = 1  # Start-of-sentence token
EOS_token = 2  # End-of-sentence token

class Voc:
    """Vocabulary mapping words to integer indexes and back.

    Indexes 0-2 are reserved for the PAD, SOS and EOS tokens; every other
    word receives the next free index the first time it is added.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, name):
        self.name = name
        self.trimmed = False  # Set once trim() has run; later calls are no-ops
        self.word2index = {}  # word -> index
        self.word2count = {}  # word -> number of times it was added
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3  # Count SOS, EOS, PAD

    def addSentence(self, sentence):
        """Add every space-separated word of ``sentence`` to the vocabulary."""
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        """Register ``word``, or bump its count if it is already known."""
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1
        else:
            self.word2count[word] += 1

    # Remove words below a certain count threshold
    def trim(self, min_count):
        """Drop every word seen fewer than ``min_count`` times and re-index.

        Only the first call has an effect. Surviving words keep their
        original counts and are re-indexed in their original insertion
        order, starting from 3.
        """
        if self.trimmed:
            return
        self.trimmed = True

        kept = [(word, count) for word, count in self.word2count.items()
                if count >= min_count]

        total = len(self.word2index)
        # Guard the ratio so trimming an empty vocabulary does not divide by zero.
        ratio = len(kept) / total if total else 0.0
        print('keep_words {} / {} = {:.4f}'.format(len(kept), total, ratio))

        # Reinitialize dictionaries
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3 # Count default tokens

        for word, count in kept:
            self.addWord(word)
            # addWord starts a new word at 1; restore the real frequency.
            self.word2count[word] = count

In [43]:
MAX_LENGTH = 10  # Maximum sentence length to consider

# Turn a Unicode string to plain ASCII, thanks to
# https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    """Strip combining marks from ``s``.

    The string is decomposed with NFD normalization and every character in
    Unicode category ``Mn`` is removed; all other characters are kept.
    """
    decomposed = unicodedata.normalize('NFD', s)
    kept_chars = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']
    return ''.join(kept_chars)

# Lowercase, trim, and remove non-letter characters
def normalizeString(s):
    """Normalize a sentence for the vocabulary.

    Lowercases and strips ``s``, removes combining marks, puts a space
    before each ``.``, ``!`` or ``?``, replaces every run of characters
    other than letters and those three punctuation marks with one space,
    then collapses whitespace and strips the result.
    """
    cleaned = unicodeToAscii(s.lower().strip())
    substitutions = (
        (r"([.!?])", r" \1"),        # detach sentence punctuation
        (r"[^a-zA-Z.!?]+", r" "),    # drop everything else that is not a letter
        (r"\s+", r" "),              # collapse whitespace
    )
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()

# Read query/response pairs and return a voc object
def readVocs(datafile, corpus_name):
    """Read the tab-separated pair file and create an empty vocabulary.

    Args:
        datafile: Path of a UTF-8 file with one tab-separated pair per line.
        corpus_name: Name handed to the new ``Voc``.

    Returns:
        ``(voc, pairs)`` where ``voc`` is a fresh, still empty ``Voc`` and
        ``pairs`` is a list of lists of normalized strings, one inner list
        per line of the file.
    """
    print("Reading lines...")
    # Read the file and split into lines; the context manager closes the
    # handle, which the previous bare open(...).read() never did.
    with open(datafile, encoding='utf-8') as f:
        lines = f.read().strip().split('\n')
    # Split every line into pairs and normalize
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]
    voc = Voc(corpus_name)
    return voc, pairs

# Returns True iff both sentences in a pair 'p' are under the MAX_LENGTH threshold
def filterPair(p):
    """Check that both sentences of ``p`` have fewer than MAX_LENGTH words.

    Words are counted by splitting on single spaces. The strict comparison
    leaves one position free for the EOS token.
    """
    if len(p[0].split(' ')) >= MAX_LENGTH:
        return False
    return len(p[1].split(' ')) < MAX_LENGTH

# Filter pairs using filterPair condition
def filterPairs(pairs):
    """Return a new list holding only the pairs accepted by ``filterPair``."""
    return list(filter(filterPair, pairs))

# Using the functions defined above, return a populated voc object and pairs list
def loadPrepareData(corpus, corpus_name, datafile, save_dir):
    """Load the pair file, drop over-long pairs and fill the vocabulary.

    Args:
        corpus: Not used by this function.
        corpus_name: Name given to the vocabulary.
        datafile: Path of the tab-separated pair file.
        save_dir: Not used by this function.

    Returns:
        ``(voc, pairs)``: the populated vocabulary and the filtered pairs.
    """
    print("Start preparing training data ...")
    voc, all_pairs = readVocs(datafile, corpus_name)
    print("Read {!s} sentence pairs".format(len(all_pairs)))
    short_pairs = filterPairs(all_pairs)
    print("Trimmed to {!s} sentence pairs".format(len(short_pairs)))
    print("Counting words...")
    # Both sides of every surviving pair contribute to the vocabulary.
    for entry in short_pairs:
        voc.addSentence(entry[0])
        voc.addSentence(entry[1])
    print("Counted words:", voc.num_words)
    return voc, short_pairs

In [44]:
# Load/Assemble voc and pairs
# `save_dir` is also passed to loadPrepareData, which does not use it.
save_dir = os.path.join("data", "save")
voc, pairs = loadPrepareData(corpus, corpus_name, datafile, save_dir)
# Print some pairs to validate
print("\npairs:")
for pair in pairs[:10]:
    print(pair)

Start preparing training data ...
Reading lines...
Read 89862 sentence pairs
Trimmed to 18843 sentence pairs
Counting words...
Counted words: 5910

pairs:
['the kitchen stinks .', 'i ll throw out the garbage .']
['would you mind waiting a while ?', 'well how long will it be ?']
['can you manage chopsticks ?', 'why not ? see .']
['i m exhausted .', 'okay let s go home .']
['no we don t .', 'how many of you please ?']
['how many of you please ?', 'six including two kids .']
['what kind of food do you like ?', 'i like chinese food .']
['i like chinese food .', 'but your american ?']
['i need a packet of cigarettes please .', 'of course sir no problem .']
['of course sir no problem .', 'thanks .']


### Trimming rarely used words out of vocab

In [45]:
MIN_COUNT = 3    # Minimum word count threshold for trimming

def trimRareWords(voc, pairs, MIN_COUNT):
    """Trim rare words from ``voc`` and drop pairs that use them.

    Args:
        voc: Vocabulary object; its ``trim`` method is called with
            ``MIN_COUNT`` and its ``word2index`` is then used for lookups.
        pairs: List of ``[input_sentence, output_sentence]`` pairs.
        MIN_COUNT: Minimum word count threshold passed to ``voc.trim``.

    Returns:
        A new list containing only the pairs whose two sentences consist
        entirely of words still present in ``voc.word2index``.
    """
    # Trim words used under the MIN_COUNT from the voc
    voc.trim(MIN_COUNT)
    known_words = voc.word2index

    def _all_known(sentence):
        # True when every space-separated word survived the trim.
        return all(word in known_words for word in sentence.split(' '))

    # Only keep pairs that do not contain trimmed word(s) in their input or output sentence
    keep_pairs = [pair for pair in pairs
                  if all(_all_known(sentence) for sentence in (pair[0], pair[1]))]

    # An empty input list used to raise ZeroDivisionError here.
    kept_ratio = len(keep_pairs) / len(pairs) if len(pairs) else 0.0
    print("Trimmed from {} pairs to {}, {:.4f} of total".format(len(pairs), len(keep_pairs), kept_ratio))
    return keep_pairs


# Trim voc and pairs
pairs = trimRareWords(voc, pairs, MIN_COUNT)

keep_words 3169 / 5907 = 0.5365
Trimmed from 18843 pairs to 15677, 0.8320 of total


## Prepare Data for Models

In [46]:
def indexesFromSentence(voc, sentence):
    """Convert a sentence to a list of word indexes ending with EOS_token.

    Words are obtained by splitting on single spaces; a word missing from
    ``voc.word2index`` raises ``KeyError``.
    """
    indexes = [voc.word2index[word] for word in sentence.split(' ')]
    indexes.append(EOS_token)
    return indexes


def zeroPadding(l, fillvalue=PAD_token):
    """Transpose a list of sequences, padding short ones with ``fillvalue``.

    Returns a list of tuples: entry ``t`` holds the ``t``-th item of every
    sequence in ``l``, with ``fillvalue`` where a sequence is shorter.
    """
    padded_columns = itertools.zip_longest(*l, fillvalue=fillvalue)
    return list(padded_columns)

def binaryMatrix(l, value=PAD_token):
    """Build a 0/1 mask with the same layout as ``l``.

    Args:
        l: Sequence of sequences of tokens.
        value: Token treated as padding. Previously this argument was
            ignored and the global PAD_token was always used; the default
            keeps existing calls unchanged.

    Returns:
        A list of lists holding 0 where the token equals ``value`` and 1
        everywhere else.
    """
    return [[0 if token == value else 1 for token in seq] for seq in l]

# Returns padded input sequence tensor and lengths
def inputVar(l, voc):
    """Turn a batch of input sentences into a padded index tensor.

    Returns:
        ``(padVar, lengths)``: a LongTensor built from the padded,
        transposed index lists, and a tensor with the unpadded length of
        every sentence (EOS included).
    """
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    lengths = torch.tensor(list(map(len, indexes_batch)))
    padVar = torch.LongTensor(zeroPadding(indexes_batch))
    return padVar, lengths

# Returns padded target sequence tensor, padding mask, and max target length
def outputVar(l, voc):
    """Turn a batch of target sentences into tensors for the loss.

    Returns:
        ``(padVar, mask, max_target_len)``: a LongTensor of the padded,
        transposed index lists, a BoolTensor of the same layout that is
        False at padding positions, and the longest unpadded target length
        (EOS included).
    """
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    max_target_len = max(len(indexes) for indexes in indexes_batch)
    padList = zeroPadding(indexes_batch)
    mask = torch.BoolTensor(binaryMatrix(padList))
    padVar = torch.LongTensor(padList)
    return padVar, mask, max_target_len

# Returns all items for a given batch of pairs
def batch2TrainData(voc, pair_batch):
    """Convert a batch of sentence pairs into the tensors used for training.

    The pairs are ordered by input length (longest first) before
    conversion. The caller's list is no longer reordered in place; a
    sorted copy is used instead.

    Args:
        voc: Vocabulary used to look up word indexes.
        pair_batch: List of ``[input_sentence, output_sentence]`` pairs.

    Returns:
        ``(inp, lengths, output, mask, max_target_len)`` as produced by
        ``inputVar`` and ``outputVar``.
    """
    # sorted() is stable like list.sort(), so the resulting order is unchanged.
    ordered_batch = sorted(pair_batch, key=lambda x: len(x[0].split(" ")), reverse=True)
    input_batch = [pair[0] for pair in ordered_batch]
    output_batch = [pair[1] for pair in ordered_batch]
    inp, lengths = inputVar(input_batch, voc)
    output, mask, max_target_len = outputVar(output_batch, voc)
    return inp, lengths, output, mask, max_target_len


# Example for validation
# Draws a small random batch (with replacement) and prints every tensor that
# batch2TrainData returns so the layout can be inspected by eye.
small_batch_size = 5
batches = batch2TrainData(voc, [random.choice(pairs) for _ in range(small_batch_size)])
input_variable, lengths, target_variable, mask, max_target_len = batches

print("input_variable:", input_variable)
print("lengths:", lengths)
print("target_variable:", target_variable)
print("mask:", mask)
print("max_target_len:", max_target_len)

input_variable: tensor([[  18,  132,  432,   78,  336],
        [   5,    5,   96,  276,    5],
        [  63,  243,    3,  110,    2],
        [  15,    3, 1055,    3,    0],
        [2309, 1726,  196, 1088,    0],
        [ 368,  253,   64,    5,    0],
        [ 175,   69,  439,    2,    0],
        [2880,  468,    5,    0,    0],
        [   5,    5,    2,    0,    0],
        [   2,    2,    0,    0,    0]])
lengths: tensor([10, 10,  9,  7,  3])
target_variable: tensor([[ 123,   84,   81,    6,   27],
        [  12, 1228,   73,   29,   38],
        [   5,   80,   12,    5,   39],
        [   2,  123,  154,    2,   12],
        [   0,   12,    5,    0,  239],
        [   0,    5,    2,    0, 1703],
        [   0,    2,    0,    0,   17],
        [   0,    0,    0,    0,    2]])
mask: tensor([[ True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True],
        [False,  True

## Seq2Seq

 ### Encoder

In [47]:
class EncoderRNN(nn.Module):
    """Bidirectional GRU encoder over embedded word indexes.

    Args:
        hidden_size: Size of the GRU hidden state; also used as the GRU
            input size, so the embedding must produce ``hidden_size``
            features.
        embedding: Module mapping word indexes to embeddings.
        n_layers: Number of stacked GRU layers.
        dropout: Dropout between GRU layers; forced to 0 when
            ``n_layers == 1``.
    """

    def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
        super(EncoderRNN, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embedding = embedding

        # Input and hidden sizes are both 'hidden_size' because the GRU is fed
        # word embeddings with that many features.
        gru_dropout = dropout if n_layers != 1 else 0
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers,
                          dropout=gru_dropout, bidirectional=True)

    def forward(self, input_seq, input_lengths, hidden=None):
        """Encode a padded batch.

        Returns ``(outputs, hidden)``: the GRU outputs with the two
        direction halves of the feature axis added together, and the GRU's
        final hidden state.
        """
        # Embed the word indexes and pack the padded batch for the RNN
        packed = nn.utils.rnn.pack_padded_sequence(self.embedding(input_seq), input_lengths)
        # Run the GRU, then restore the padded layout
        packed_outputs, hidden = self.gru(packed, hidden)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_outputs)
        # Add the two directions' halves of the feature axis
        first_half = outputs[:, :, :self.hidden_size]
        second_half = outputs[:, :, self.hidden_size:]
        return first_half + second_half, hidden

### Decoder

In [48]:
# Luong attention layer
class Attn(nn.Module):
    """Attention scores between a decoder state and the encoder outputs.

    Args:
        method: One of ``'dot'``, ``'general'`` or ``'concat'``.
        hidden_size: Feature size of the hidden states being compared.

    Raises:
        ValueError: If ``method`` is not one of the three supported names.
    """

    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()
        self.method = method
        if self.method not in ['dot', 'general', 'concat']:
            raise ValueError(self.method, "is not an appropriate attention method.")
        self.hidden_size = hidden_size
        if self.method == 'general':
            self.attn = nn.Linear(self.hidden_size, hidden_size)
        elif self.method == 'concat':
            self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
            # torch.FloatTensor(n) hands back uninitialized memory, so `v`
            # could start as arbitrary (even non-finite) values. Draw it from
            # a small uniform range instead.
            bound = hidden_size ** -0.5
            self.v = nn.Parameter(
                torch.empty(hidden_size, dtype=torch.float).uniform_(-bound, bound))

    def dot_score(self, hidden, encoder_output):
        """Sum of the elementwise product over the feature axis (dim 2)."""
        return torch.sum(hidden * encoder_output, dim=2)

    def general_score(self, hidden, encoder_output):
        """Dot score against a linear projection of the encoder output."""
        energy = self.attn(encoder_output)
        return torch.sum(hidden * energy, dim=2)

    def concat_score(self, hidden, encoder_output):
        """Score from tanh(W [hidden; encoder_output]) weighted by ``v``."""
        # hidden is expanded along dim 0 to match the encoder output before concatenation
        energy = self.attn(torch.cat((hidden.expand(encoder_output.size(0), -1, -1), encoder_output), 2)).tanh()
        return torch.sum(self.v * energy, dim=2)

    def forward(self, hidden, encoder_outputs):
        """Return softmax-normalized attention weights.

        The raw scores are transposed, normalized over dim 1 and given an
        extra axis at position 1.
        """
        # Calculate the attention weights (energies) based on the given method
        if self.method == 'general':
            attn_energies = self.general_score(hidden, encoder_outputs)
        elif self.method == 'concat':
            attn_energies = self.concat_score(hidden, encoder_outputs)
        elif self.method == 'dot':
            attn_energies = self.dot_score(hidden, encoder_outputs)

        # Transpose max_length and batch_size dimensions
        attn_energies = attn_energies.t()

        # Return the softmax normalized probability scores (with added dimension)
        return F.softmax(attn_energies, dim=1).unsqueeze(1)

In [49]:
class LuongAttnDecoderRNN(nn.Module):
    """Single-step GRU decoder with Luong attention.

    Args:
        attn_model: Attention method name handed to ``Attn``.
        embedding: Module mapping word indexes to embeddings.
        hidden_size: GRU input and hidden size.
        output_size: Number of output classes.
        n_layers: Number of stacked GRU layers.
        dropout: Dropout applied to the embeddings and, when
            ``n_layers > 1``, between GRU layers.
    """

    def __init__(self, attn_model, embedding, hidden_size, output_size, n_layers=1, dropout=0.1):
        super(LuongAttnDecoderRNN, self).__init__()

        # Hyper-parameters, stored for later reference
        self.attn_model = attn_model
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = dropout

        # Layers
        self.embedding = embedding
        self.embedding_dropout = nn.Dropout(dropout)
        gru_dropout = dropout if n_layers != 1 else 0
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=gru_dropout)
        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

        self.attn = Attn(attn_model, hidden_size)

    def forward(self, input_step, last_hidden, encoder_outputs):
        """Decode one time step.

        Returns ``(output, hidden)``: softmax probabilities over the output
        classes and the GRU's new hidden state.
        """
        # One step (word) at a time: embed the current input word, with dropout
        embedded = self.embedding_dropout(self.embedding(input_step))
        # Unidirectional GRU step
        rnn_output, hidden = self.gru(embedded, last_hidden)
        # Attention weights from the current GRU output
        attn_weights = self.attn(rnn_output, encoder_outputs)
        # Weighted sum of encoder outputs gives the context vector
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1)).squeeze(1)
        # Luong eq. 5: combine GRU output and context
        concat_input = torch.cat((rnn_output.squeeze(0), context), 1)
        concat_output = torch.tanh(self.concat(concat_input))
        # Luong eq. 6: predict the next word
        output = F.softmax(self.out(concat_output), dim=1)
        return output, hidden

## Training 

In [50]:
def maskNLLLoss(inp, target, mask):
    """Mean negative log-likelihood over the unmasked entries.

    Args:
        inp: Probabilities; the entry for each target is gathered along dim 1.
        target: Target class indexes.
        mask: Boolean mask selecting which entries count towards the loss.

    Returns:
        ``(loss, nTotal)``: the mean loss over selected entries, moved to
        ``device``, and the number of True entries in ``mask`` as a Python
        number.
    """
    target_probs = torch.gather(inp, 1, target.view(-1, 1)).squeeze(1)
    crossEntropy = -torch.log(target_probs)
    loss = crossEntropy.masked_select(mask).mean().to(device)
    return loss, mask.sum().item()

Single training iteration -> teacher forcing and gradient clipping

In [51]:
def train(input_variable, lengths, target_variable, mask, max_target_len, encoder, decoder, embedding,
          encoder_optimizer, decoder_optimizer, batch_size, clip, max_length=MAX_LENGTH):
    """Run one training iteration on a single batch.

    Encodes the batch, decodes it one time step at a time (feeding either
    the target or the decoder's own top prediction as the next input,
    chosen once per call from ``teacher_forcing_ratio``), backpropagates
    the summed masked loss, clips gradients and steps both optimizers.

    ``embedding`` and ``max_length`` are accepted but not used here.

    Returns:
        The loss averaged over all unmasked target tokens of the batch.
    """
    # Zero gradients
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    # Set device options
    input_variable = input_variable.to(device)
    target_variable = target_variable.to(device)
    mask = mask.to(device)
    # Lengths for rnn packing should always be on the cpu
    lengths = lengths.to("cpu")

    # Forward pass through encoder
    encoder_outputs, encoder_hidden = encoder(input_variable, lengths)

    # Every sentence starts decoding from the SOS token
    decoder_input = torch.LongTensor([[SOS_token for _ in range(batch_size)]])
    decoder_input = decoder_input.to(device)

    # The decoder starts from the encoder's final hidden state
    decoder_hidden = encoder_hidden[:decoder.n_layers]

    # Decide once per iteration whether to use teacher forcing
    use_teacher_forcing = random.random() < teacher_forcing_ratio

    loss = 0
    weighted_losses = []
    n_totals = 0

    # Forward batch of sequences through decoder one time step at a time
    for t in range(max_target_len):
        decoder_output, decoder_hidden = decoder(
            decoder_input, decoder_hidden, encoder_outputs
        )
        if use_teacher_forcing:
            # Teacher forcing: next input is current target
            decoder_input = target_variable[t].view(1, -1)
        else:
            # No teacher forcing: next input is decoder's own current output
            _, topi = decoder_output.topk(1)
            decoder_input = torch.LongTensor([[topi[i][0] for i in range(batch_size)]])
            decoder_input = decoder_input.to(device)
        # Calculate and accumulate loss
        step_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
        loss += step_loss
        weighted_losses.append(step_loss.item() * nTotal)
        n_totals += nTotal

    # Perform backpropagation
    loss.backward()

    # Clip gradients: gradients are modified in place
    nn.utils.clip_grad_norm_(encoder.parameters(), clip)
    nn.utils.clip_grad_norm_(decoder.parameters(), clip)

    # Adjust model weights
    encoder_optimizer.step()
    decoder_optimizer.step()

    return sum(weighted_losses) / n_totals

Training iterations

In [52]:
def trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer, embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size, print_every, save_every, clip, corpus_name, loadFilename):
    """Train for ``n_iteration`` iterations, logging and checkpointing.

    A batch of ``batch_size`` random pairs (sampled with replacement) is
    prepared up front for every iteration. The average loss is printed
    every ``print_every`` iterations and a checkpoint is written every
    ``save_every`` iterations under
    ``save_dir/model_name/corpus_name/<enc_layers>-<dec_layers>_<hidden_size>/``.

    NOTE(review): this function reads two names that are not parameters:
    ``checkpoint`` (only when ``loadFilename`` is truthy) and
    ``hidden_size`` (for the checkpoint directory). Both must exist as
    globals when it is called.
    """
    # Load batches for each iteration
    training_batches = [batch2TrainData(voc, [random.choice(pairs) for _ in range(batch_size)])
                      for _ in range(n_iteration)]

    # Initializations
    print('Initializing ...')
    start_iteration = 1
    print_loss = 0
    if loadFilename:
        # Resume right after the iteration stored in the loaded checkpoint
        start_iteration = checkpoint['iteration'] + 1

    # Training loop
    print("Training...")
    for iteration in range(start_iteration, n_iteration + 1):
        training_batch = training_batches[iteration - 1]
        # Extract fields from batch
        input_variable, lengths, target_variable, mask, max_target_len = training_batch

        # Run a training iteration with batch
        loss = train(input_variable, lengths, target_variable, mask, max_target_len, encoder,
                     decoder, embedding, encoder_optimizer, decoder_optimizer, batch_size, clip)
        print_loss += loss

        # Print progress
        if iteration % print_every == 0:
            print_loss_avg = print_loss / print_every
            print("Iteration: {}; Percent complete: {:.1f}%; Average loss: {:.4f}".format(iteration, iteration / n_iteration * 100, print_loss_avg))
            print_loss = 0

        # Save checkpoint
        if (iteration % save_every == 0):
            directory = os.path.join(save_dir, model_name, corpus_name, '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, hidden_size))
            # exist_ok avoids the check-then-create race of exists() + makedirs()
            os.makedirs(directory, exist_ok=True)
            torch.save({
                'iteration': iteration,
                'en': encoder.state_dict(),
                'de': decoder.state_dict(),
                'en_opt': encoder_optimizer.state_dict(),
                'de_opt': decoder_optimizer.state_dict(),
                'loss': loss,
                'voc_dict': voc.__dict__,
                'embedding': embedding.state_dict()
            }, os.path.join(directory, '{}_{}.tar'.format(iteration, 'checkpoint')))

## Greedy decoding for generating sentences

In [53]:
class GreedySearchDecoder(nn.Module):
    """Greedy decoding: at every step pick the highest-scoring token.

    Args:
        encoder: Module called as ``encoder(input_seq, input_length)`` and
            returning ``(encoder_outputs, encoder_hidden)``.
        decoder: Module with an ``n_layers`` attribute, called as
            ``decoder(decoder_input, decoder_hidden, encoder_outputs)``.
    """

    def __init__(self, encoder, decoder):
        super(GreedySearchDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, input_seq, input_length, max_length):
        """Decode ``max_length`` tokens for a single input sequence.

        Returns:
            ``(all_tokens, all_scores)``: the chosen token indexes and the
            decoder score of each chosen token, one entry per step.
        """
        # Forward input through encoder model
        encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)
        # Prepare encoder's final hidden layer to be first hidden input to the decoder.
        # Use this module's own decoder rather than whatever global `decoder` exists.
        decoder_hidden = encoder_hidden[:self.decoder.n_layers]
        # Initialize decoder input with SOS_token
        decoder_input = torch.ones(1, 1, device=device, dtype=torch.long) * SOS_token
        # Initialize tensors to append decoded words to
        all_tokens = torch.zeros([0], device=device, dtype=torch.long)
        all_scores = torch.zeros([0], device=device)
        # Iteratively decode one word token at a time
        for _ in range(max_length):
            # Forward pass through decoder
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
            # Obtain most likely word token and its softmax score
            decoder_scores, decoder_input = torch.max(decoder_output, dim=1)
            # Record token and score
            all_tokens = torch.cat((all_tokens, decoder_input), dim=0)
            all_scores = torch.cat((all_scores, decoder_scores), dim=0)
            # Prepare current token to be next decoder input (add a dimension)
            decoder_input = torch.unsqueeze(decoder_input, 0)
        # Return collections of word tokens and scores
        return all_tokens, all_scores

## Evaluation of model

In [54]:
def evaluate(encoder, decoder, searcher, voc, sentence, max_length=MAX_LENGTH):
    """Decode a reply for one already-normalized sentence.

    Args:
        encoder: Not used directly; decoding goes through ``searcher``.
        decoder: Not used directly; decoding goes through ``searcher``.
        searcher: Callable invoked as
            ``searcher(input_batch, lengths, max_length)`` and returning
            ``(tokens, scores)``.
        voc: Vocabulary used for word/index conversion.
        sentence: Space-separated input sentence.
        max_length: Number of decoding steps requested from ``searcher``.

    Returns:
        The decoded words as a list of strings.

    Raises:
        KeyError: If ``sentence`` contains a word missing from ``voc``.
    """
    ### Format input sentence as a batch
    # words -> indexes
    indexes_batch = [indexesFromSentence(voc, sentence)]
    # Create lengths tensor
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    # Transpose dimensions of batch to match models' expectations
    input_batch = torch.LongTensor(indexes_batch).transpose(0, 1)
    # Use appropriate device
    input_batch = input_batch.to(device)
    # Lengths for rnn packing should always be on the cpu (same rule as in
    # train()); moving them to a CUDA device breaks sequence packing.
    lengths = lengths.to("cpu")
    # Decode sentence with searcher
    tokens, scores = searcher(input_batch, lengths, max_length)
    # indexes -> words
    decoded_words = [voc.index2word[token.item()] for token in tokens]
    return decoded_words


def evaluateInput(encoder, decoder, searcher, voc):
    """Interactive chat loop reading sentences from standard input.

    Each entered sentence is normalized, decoded with ``evaluate`` and the
    reply printed without its EOS/PAD tokens. Typing ``q`` or ``quit``
    ends the loop; a sentence containing an unknown word prints an error
    message and the loop continues.
    """
    while True:
        try:
            # Read the next sentence from the user
            input_sentence = input('> ')
            # Leave the loop on the quit commands
            if input_sentence in ('q', 'quit'):
                break
            # Normalize, then decode a reply
            input_sentence = normalizeString(input_sentence)
            output_words = evaluate(encoder, decoder, searcher, voc, input_sentence)
            # Drop end-of-sentence and padding markers before printing
            reply_words = [word for word in output_words if word != 'EOS' and word != 'PAD']
            print('Bot:', ' '.join(reply_words))

        except KeyError:
            print("Error: Encountered unknown word.")

## Run model

The model can either be built from scratch or loaded from a saved checkpoint.

In [55]:
# Configure models
model_name = 'cb_model'
# Attention scoring method handed to the decoder (one of 'dot', 'general', 'concat')
attn_model = 'dot'
#attn_model = 'general'
#attn_model = 'concat'
hidden_size = 500  # Also the embedding size and part of the checkpoint directory name
encoder_n_layers = 2
decoder_n_layers = 2
dropout = 0.1
batch_size = 64

# Set checkpoint to load from; set to None if starting from scratch
loadFilename = None
# Iteration number of the checkpoint referenced by the commented-out path below
checkpoint_iter = 4000
#loadFilename = os.path.join(save_dir, model_name, corpus_name,
#                            '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, hidden_size),
#                            '{}_checkpoint.tar'.format(checkpoint_iter))


# Load model if a loadFilename is provided
if loadFilename:
    # NOTE(review): torch.load deserializes the checkpoint file; only load
    # checkpoints from a source you trust.
    # If loading on same machine the model was trained on
    checkpoint = torch.load(loadFilename)
    # If loading a model trained on GPU to CPU
    #checkpoint = torch.load(loadFilename, map_location=torch.device('cpu'))
    encoder_sd = checkpoint['en']
    decoder_sd = checkpoint['de']
    encoder_optimizer_sd = checkpoint['en_opt']
    decoder_optimizer_sd = checkpoint['de_opt']
    embedding_sd = checkpoint['embedding']
    # Replace the vocabulary's whole state with the one saved in the checkpoint
    voc.__dict__ = checkpoint['voc_dict']


print('Building encoder and decoder ...')
# Initialize word embeddings
# The same embedding module is handed to both the encoder and the decoder.
embedding = nn.Embedding(voc.num_words, hidden_size)
if loadFilename:
    embedding.load_state_dict(embedding_sd)
# Initialize encoder & decoder models
encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.num_words, decoder_n_layers, dropout)
if loadFilename:
    encoder.load_state_dict(encoder_sd)
    decoder.load_state_dict(decoder_sd)
# Use appropriate device
encoder = encoder.to(device)
decoder = decoder.to(device)
print('Models built and ready to go!')

Building encoder and decoder ...
Models built and ready to go!


## Train!

In [56]:
# Configure training/optimization
clip = 50.0
teacher_forcing_ratio = 1.0
learning_rate = 0.0001
# The decoder optimizer runs at learning_rate * decoder_learning_ratio
decoder_learning_ratio = 5.0
n_iteration = 4000
print_every = 1
save_every = 500

# Ensure dropout layers are in train mode
encoder.train()
decoder.train()

# Initialize optimizers
print('Building optimizers ...')
encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate * decoder_learning_ratio)
if loadFilename:
    encoder_optimizer.load_state_dict(encoder_optimizer_sd)
    decoder_optimizer.load_state_dict(decoder_optimizer_sd)

# Keep any restored optimizer state on the same device as the models.
# `.to(device)` leaves the tensors alone on a CPU-only machine, where an
# unconditional `.cuda()` would raise.
for optimizer in (encoder_optimizer, decoder_optimizer):
    for state in optimizer.state.values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor):
                state[k] = v.to(device)

# Run training iterations
print("Starting Training!")
trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer,
           embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size,
           print_every, save_every, clip, corpus_name, loadFilename)

Building optimizers ...
Starting Training!
Initializing ...
Training...
Iteration: 1; Percent complete: 0.0%; Average loss: 8.0671
Iteration: 2; Percent complete: 0.1%; Average loss: 7.9697
Iteration: 3; Percent complete: 0.1%; Average loss: 7.8163
Iteration: 4; Percent complete: 0.1%; Average loss: 7.5624
Iteration: 5; Percent complete: 0.1%; Average loss: 7.1904
Iteration: 6; Percent complete: 0.1%; Average loss: 6.6989
Iteration: 7; Percent complete: 0.2%; Average loss: 6.3878
Iteration: 8; Percent complete: 0.2%; Average loss: 6.4371
Iteration: 9; Percent complete: 0.2%; Average loss: 6.2954
Iteration: 10; Percent complete: 0.2%; Average loss: 5.8307
Iteration: 11; Percent complete: 0.3%; Average loss: 5.5280
Iteration: 12; Percent complete: 0.3%; Average loss: 5.5203
Iteration: 13; Percent complete: 0.3%; Average loss: 5.4158
Iteration: 14; Percent complete: 0.4%; Average loss: 5.3839
Iteration: 15; Percent complete: 0.4%; Average loss: 5.1089
Iteration: 16; Percent complete: 0.4%

Iteration: 136; Percent complete: 3.4%; Average loss: 4.0420
Iteration: 137; Percent complete: 3.4%; Average loss: 3.8476
Iteration: 138; Percent complete: 3.5%; Average loss: 4.0541
Iteration: 139; Percent complete: 3.5%; Average loss: 3.9226
Iteration: 140; Percent complete: 3.5%; Average loss: 3.8385
Iteration: 141; Percent complete: 3.5%; Average loss: 3.9271
Iteration: 142; Percent complete: 3.5%; Average loss: 4.0066
Iteration: 143; Percent complete: 3.6%; Average loss: 4.0675
Iteration: 144; Percent complete: 3.6%; Average loss: 3.9253
Iteration: 145; Percent complete: 3.6%; Average loss: 3.9525
Iteration: 146; Percent complete: 3.6%; Average loss: 3.9330
Iteration: 147; Percent complete: 3.7%; Average loss: 3.8939
Iteration: 148; Percent complete: 3.7%; Average loss: 3.9568
Iteration: 149; Percent complete: 3.7%; Average loss: 3.8698
Iteration: 150; Percent complete: 3.8%; Average loss: 3.6560
Iteration: 151; Percent complete: 3.8%; Average loss: 3.7619
Iteration: 152; Percent 

Iteration: 271; Percent complete: 6.8%; Average loss: 3.5904
Iteration: 272; Percent complete: 6.8%; Average loss: 3.5888
Iteration: 273; Percent complete: 6.8%; Average loss: 3.6322
Iteration: 274; Percent complete: 6.9%; Average loss: 3.6205
Iteration: 275; Percent complete: 6.9%; Average loss: 3.4620
Iteration: 276; Percent complete: 6.9%; Average loss: 3.5598
Iteration: 277; Percent complete: 6.9%; Average loss: 3.4742
Iteration: 278; Percent complete: 7.0%; Average loss: 3.4340
Iteration: 279; Percent complete: 7.0%; Average loss: 3.4750
Iteration: 280; Percent complete: 7.0%; Average loss: 3.5118
Iteration: 281; Percent complete: 7.0%; Average loss: 3.6751
Iteration: 282; Percent complete: 7.0%; Average loss: 3.3940
Iteration: 283; Percent complete: 7.1%; Average loss: 3.3955
Iteration: 284; Percent complete: 7.1%; Average loss: 3.4432
Iteration: 285; Percent complete: 7.1%; Average loss: 3.6085
Iteration: 286; Percent complete: 7.1%; Average loss: 3.5547
Iteration: 287; Percent 

Iteration: 406; Percent complete: 10.2%; Average loss: 3.4062
Iteration: 407; Percent complete: 10.2%; Average loss: 3.1839
Iteration: 408; Percent complete: 10.2%; Average loss: 3.2855
Iteration: 409; Percent complete: 10.2%; Average loss: 3.4237
Iteration: 410; Percent complete: 10.2%; Average loss: 3.4918
Iteration: 411; Percent complete: 10.3%; Average loss: 3.1135
Iteration: 412; Percent complete: 10.3%; Average loss: 3.1273
Iteration: 413; Percent complete: 10.3%; Average loss: 3.3236
Iteration: 414; Percent complete: 10.3%; Average loss: 3.2770
Iteration: 415; Percent complete: 10.4%; Average loss: 3.3909
Iteration: 416; Percent complete: 10.4%; Average loss: 3.1655
Iteration: 417; Percent complete: 10.4%; Average loss: 3.4968
Iteration: 418; Percent complete: 10.4%; Average loss: 3.1295
Iteration: 419; Percent complete: 10.5%; Average loss: 3.1153
Iteration: 420; Percent complete: 10.5%; Average loss: 3.1812
Iteration: 421; Percent complete: 10.5%; Average loss: 3.1116
Iteratio

Iteration: 539; Percent complete: 13.5%; Average loss: 3.0672
Iteration: 540; Percent complete: 13.5%; Average loss: 3.0148
Iteration: 541; Percent complete: 13.5%; Average loss: 3.0798
Iteration: 542; Percent complete: 13.6%; Average loss: 2.9404
Iteration: 543; Percent complete: 13.6%; Average loss: 3.1593
Iteration: 544; Percent complete: 13.6%; Average loss: 2.9076
Iteration: 545; Percent complete: 13.6%; Average loss: 2.7963
Iteration: 546; Percent complete: 13.7%; Average loss: 3.1595
Iteration: 547; Percent complete: 13.7%; Average loss: 3.0521
Iteration: 548; Percent complete: 13.7%; Average loss: 3.0408
Iteration: 549; Percent complete: 13.7%; Average loss: 2.8725
Iteration: 550; Percent complete: 13.8%; Average loss: 3.0178
Iteration: 551; Percent complete: 13.8%; Average loss: 3.0831
Iteration: 552; Percent complete: 13.8%; Average loss: 3.1612
Iteration: 553; Percent complete: 13.8%; Average loss: 2.9606
Iteration: 554; Percent complete: 13.9%; Average loss: 2.7631
Iteratio

Iteration: 672; Percent complete: 16.8%; Average loss: 2.9304
Iteration: 673; Percent complete: 16.8%; Average loss: 2.7176
Iteration: 674; Percent complete: 16.9%; Average loss: 3.2420
Iteration: 675; Percent complete: 16.9%; Average loss: 2.9557
Iteration: 676; Percent complete: 16.9%; Average loss: 2.8234
Iteration: 677; Percent complete: 16.9%; Average loss: 2.9852
Iteration: 678; Percent complete: 17.0%; Average loss: 2.8672
Iteration: 679; Percent complete: 17.0%; Average loss: 2.6662
Iteration: 680; Percent complete: 17.0%; Average loss: 2.8898
Iteration: 681; Percent complete: 17.0%; Average loss: 2.7928
Iteration: 682; Percent complete: 17.1%; Average loss: 2.7172
Iteration: 683; Percent complete: 17.1%; Average loss: 2.7043
Iteration: 684; Percent complete: 17.1%; Average loss: 3.0475
Iteration: 685; Percent complete: 17.1%; Average loss: 2.8630
Iteration: 686; Percent complete: 17.2%; Average loss: 2.8860
Iteration: 687; Percent complete: 17.2%; Average loss: 2.5754
Iteratio

Iteration: 805; Percent complete: 20.1%; Average loss: 2.6861
Iteration: 806; Percent complete: 20.2%; Average loss: 2.6548
Iteration: 807; Percent complete: 20.2%; Average loss: 2.8825
Iteration: 808; Percent complete: 20.2%; Average loss: 2.6252
Iteration: 809; Percent complete: 20.2%; Average loss: 2.6750
Iteration: 810; Percent complete: 20.2%; Average loss: 2.5570
Iteration: 811; Percent complete: 20.3%; Average loss: 2.4383
Iteration: 812; Percent complete: 20.3%; Average loss: 2.5316
Iteration: 813; Percent complete: 20.3%; Average loss: 2.6829
Iteration: 814; Percent complete: 20.3%; Average loss: 2.6264
Iteration: 815; Percent complete: 20.4%; Average loss: 2.5490
Iteration: 816; Percent complete: 20.4%; Average loss: 2.5923
Iteration: 817; Percent complete: 20.4%; Average loss: 2.7166
Iteration: 818; Percent complete: 20.4%; Average loss: 2.6992
Iteration: 819; Percent complete: 20.5%; Average loss: 2.7680
Iteration: 820; Percent complete: 20.5%; Average loss: 2.2981
Iteratio

Iteration: 938; Percent complete: 23.4%; Average loss: 2.2383
Iteration: 939; Percent complete: 23.5%; Average loss: 2.6665
Iteration: 940; Percent complete: 23.5%; Average loss: 2.6284
Iteration: 941; Percent complete: 23.5%; Average loss: 2.5248
Iteration: 942; Percent complete: 23.5%; Average loss: 2.3247
Iteration: 943; Percent complete: 23.6%; Average loss: 2.5326
Iteration: 944; Percent complete: 23.6%; Average loss: 2.5991
Iteration: 945; Percent complete: 23.6%; Average loss: 2.4251
Iteration: 946; Percent complete: 23.6%; Average loss: 2.6529
Iteration: 947; Percent complete: 23.7%; Average loss: 2.3724
Iteration: 948; Percent complete: 23.7%; Average loss: 2.2148
Iteration: 949; Percent complete: 23.7%; Average loss: 2.4302
Iteration: 950; Percent complete: 23.8%; Average loss: 2.4737
Iteration: 951; Percent complete: 23.8%; Average loss: 2.5303
Iteration: 952; Percent complete: 23.8%; Average loss: 2.6043
Iteration: 953; Percent complete: 23.8%; Average loss: 2.2899
Iteratio

Iteration: 1070; Percent complete: 26.8%; Average loss: 2.3017
Iteration: 1071; Percent complete: 26.8%; Average loss: 2.3892
Iteration: 1072; Percent complete: 26.8%; Average loss: 2.5794
Iteration: 1073; Percent complete: 26.8%; Average loss: 2.4619
Iteration: 1074; Percent complete: 26.9%; Average loss: 2.2961
Iteration: 1075; Percent complete: 26.9%; Average loss: 2.1524
Iteration: 1076; Percent complete: 26.9%; Average loss: 2.1552
Iteration: 1077; Percent complete: 26.9%; Average loss: 2.2536
Iteration: 1078; Percent complete: 27.0%; Average loss: 2.3513
Iteration: 1079; Percent complete: 27.0%; Average loss: 2.2373
Iteration: 1080; Percent complete: 27.0%; Average loss: 2.1719
Iteration: 1081; Percent complete: 27.0%; Average loss: 2.3484
Iteration: 1082; Percent complete: 27.1%; Average loss: 2.3101
Iteration: 1083; Percent complete: 27.1%; Average loss: 2.3856
Iteration: 1084; Percent complete: 27.1%; Average loss: 2.1575
Iteration: 1085; Percent complete: 27.1%; Average loss:

Iteration: 1201; Percent complete: 30.0%; Average loss: 2.1694
Iteration: 1202; Percent complete: 30.0%; Average loss: 2.1326
Iteration: 1203; Percent complete: 30.1%; Average loss: 2.1741
Iteration: 1204; Percent complete: 30.1%; Average loss: 2.1563
Iteration: 1205; Percent complete: 30.1%; Average loss: 2.2018
Iteration: 1206; Percent complete: 30.1%; Average loss: 2.0500
Iteration: 1207; Percent complete: 30.2%; Average loss: 2.1216
Iteration: 1208; Percent complete: 30.2%; Average loss: 2.2678
Iteration: 1209; Percent complete: 30.2%; Average loss: 2.2526
Iteration: 1210; Percent complete: 30.2%; Average loss: 2.3468
Iteration: 1211; Percent complete: 30.3%; Average loss: 2.0393
Iteration: 1212; Percent complete: 30.3%; Average loss: 2.0443
Iteration: 1213; Percent complete: 30.3%; Average loss: 2.2344
Iteration: 1214; Percent complete: 30.3%; Average loss: 2.0822
Iteration: 1215; Percent complete: 30.4%; Average loss: 2.0432
Iteration: 1216; Percent complete: 30.4%; Average loss:

Iteration: 1332; Percent complete: 33.3%; Average loss: 2.0250
Iteration: 1333; Percent complete: 33.3%; Average loss: 2.0684
Iteration: 1334; Percent complete: 33.4%; Average loss: 1.7647
Iteration: 1335; Percent complete: 33.4%; Average loss: 2.0766
Iteration: 1336; Percent complete: 33.4%; Average loss: 1.7418
Iteration: 1337; Percent complete: 33.4%; Average loss: 1.9080
Iteration: 1338; Percent complete: 33.5%; Average loss: 1.9952
Iteration: 1339; Percent complete: 33.5%; Average loss: 1.9300
Iteration: 1340; Percent complete: 33.5%; Average loss: 2.0442
Iteration: 1341; Percent complete: 33.5%; Average loss: 1.9169
Iteration: 1342; Percent complete: 33.6%; Average loss: 2.1072
Iteration: 1343; Percent complete: 33.6%; Average loss: 1.8462
Iteration: 1344; Percent complete: 33.6%; Average loss: 2.0803
Iteration: 1345; Percent complete: 33.6%; Average loss: 1.8511
Iteration: 1346; Percent complete: 33.7%; Average loss: 2.0866
Iteration: 1347; Percent complete: 33.7%; Average loss:

Iteration: 1463; Percent complete: 36.6%; Average loss: 1.8516
Iteration: 1464; Percent complete: 36.6%; Average loss: 1.7375
Iteration: 1465; Percent complete: 36.6%; Average loss: 1.9950
Iteration: 1466; Percent complete: 36.6%; Average loss: 1.8123
Iteration: 1467; Percent complete: 36.7%; Average loss: 2.1698
Iteration: 1468; Percent complete: 36.7%; Average loss: 1.7891
Iteration: 1469; Percent complete: 36.7%; Average loss: 1.9873
Iteration: 1470; Percent complete: 36.8%; Average loss: 1.7790
Iteration: 1471; Percent complete: 36.8%; Average loss: 1.6765
Iteration: 1472; Percent complete: 36.8%; Average loss: 1.7516
Iteration: 1473; Percent complete: 36.8%; Average loss: 1.8278
Iteration: 1474; Percent complete: 36.9%; Average loss: 1.8981
Iteration: 1475; Percent complete: 36.9%; Average loss: 1.7429
Iteration: 1476; Percent complete: 36.9%; Average loss: 1.7553
Iteration: 1477; Percent complete: 36.9%; Average loss: 1.7231
Iteration: 1478; Percent complete: 37.0%; Average loss:

Iteration: 1594; Percent complete: 39.9%; Average loss: 1.8727
Iteration: 1595; Percent complete: 39.9%; Average loss: 1.7791
Iteration: 1596; Percent complete: 39.9%; Average loss: 1.6996
Iteration: 1597; Percent complete: 39.9%; Average loss: 1.6443
Iteration: 1598; Percent complete: 40.0%; Average loss: 1.5898
Iteration: 1599; Percent complete: 40.0%; Average loss: 1.7396
Iteration: 1600; Percent complete: 40.0%; Average loss: 1.6441
Iteration: 1601; Percent complete: 40.0%; Average loss: 1.5767
Iteration: 1602; Percent complete: 40.1%; Average loss: 1.7004
Iteration: 1603; Percent complete: 40.1%; Average loss: 1.7864
Iteration: 1604; Percent complete: 40.1%; Average loss: 1.6419
Iteration: 1605; Percent complete: 40.1%; Average loss: 1.4683
Iteration: 1606; Percent complete: 40.2%; Average loss: 1.6114
Iteration: 1607; Percent complete: 40.2%; Average loss: 1.6844
Iteration: 1608; Percent complete: 40.2%; Average loss: 1.6814
Iteration: 1609; Percent complete: 40.2%; Average loss:

Iteration: 1725; Percent complete: 43.1%; Average loss: 1.7988
Iteration: 1726; Percent complete: 43.1%; Average loss: 1.5254
Iteration: 1727; Percent complete: 43.2%; Average loss: 1.6316
Iteration: 1728; Percent complete: 43.2%; Average loss: 1.4760
Iteration: 1729; Percent complete: 43.2%; Average loss: 1.4277
Iteration: 1730; Percent complete: 43.2%; Average loss: 1.6129
Iteration: 1731; Percent complete: 43.3%; Average loss: 1.4853
Iteration: 1732; Percent complete: 43.3%; Average loss: 1.5798
Iteration: 1733; Percent complete: 43.3%; Average loss: 1.5088
Iteration: 1734; Percent complete: 43.4%; Average loss: 1.6683
Iteration: 1735; Percent complete: 43.4%; Average loss: 1.3027
Iteration: 1736; Percent complete: 43.4%; Average loss: 1.3987
Iteration: 1737; Percent complete: 43.4%; Average loss: 1.6183
Iteration: 1738; Percent complete: 43.5%; Average loss: 1.4377
Iteration: 1739; Percent complete: 43.5%; Average loss: 1.4909
Iteration: 1740; Percent complete: 43.5%; Average loss:

Iteration: 1856; Percent complete: 46.4%; Average loss: 1.4936
Iteration: 1857; Percent complete: 46.4%; Average loss: 1.4031
Iteration: 1858; Percent complete: 46.5%; Average loss: 1.6768
Iteration: 1859; Percent complete: 46.5%; Average loss: 1.3407
Iteration: 1860; Percent complete: 46.5%; Average loss: 1.4039
Iteration: 1861; Percent complete: 46.5%; Average loss: 1.4754
Iteration: 1862; Percent complete: 46.6%; Average loss: 1.4987
Iteration: 1863; Percent complete: 46.6%; Average loss: 1.5094
Iteration: 1864; Percent complete: 46.6%; Average loss: 1.4498
Iteration: 1865; Percent complete: 46.6%; Average loss: 1.5675
Iteration: 1866; Percent complete: 46.7%; Average loss: 1.3772
Iteration: 1867; Percent complete: 46.7%; Average loss: 1.6120
Iteration: 1868; Percent complete: 46.7%; Average loss: 1.4610
Iteration: 1869; Percent complete: 46.7%; Average loss: 1.4734
Iteration: 1870; Percent complete: 46.8%; Average loss: 1.2349
Iteration: 1871; Percent complete: 46.8%; Average loss:

Iteration: 1987; Percent complete: 49.7%; Average loss: 1.3356
Iteration: 1988; Percent complete: 49.7%; Average loss: 1.3973
Iteration: 1989; Percent complete: 49.7%; Average loss: 1.4494
Iteration: 1990; Percent complete: 49.8%; Average loss: 1.3298
Iteration: 1991; Percent complete: 49.8%; Average loss: 1.4017
Iteration: 1992; Percent complete: 49.8%; Average loss: 1.1453
Iteration: 1993; Percent complete: 49.8%; Average loss: 1.2873
Iteration: 1994; Percent complete: 49.9%; Average loss: 1.4017
Iteration: 1995; Percent complete: 49.9%; Average loss: 1.4493
Iteration: 1996; Percent complete: 49.9%; Average loss: 1.4961
Iteration: 1997; Percent complete: 49.9%; Average loss: 1.5773
Iteration: 1998; Percent complete: 50.0%; Average loss: 1.3137
Iteration: 1999; Percent complete: 50.0%; Average loss: 1.3533
Iteration: 2000; Percent complete: 50.0%; Average loss: 1.3664
Iteration: 2001; Percent complete: 50.0%; Average loss: 1.4205
Iteration: 2002; Percent complete: 50.0%; Average loss:

Iteration: 2118; Percent complete: 52.9%; Average loss: 1.4176
Iteration: 2119; Percent complete: 53.0%; Average loss: 1.2659
Iteration: 2120; Percent complete: 53.0%; Average loss: 1.3875
Iteration: 2121; Percent complete: 53.0%; Average loss: 1.3190
Iteration: 2122; Percent complete: 53.0%; Average loss: 1.4154
Iteration: 2123; Percent complete: 53.1%; Average loss: 1.2459
Iteration: 2124; Percent complete: 53.1%; Average loss: 1.3650
Iteration: 2125; Percent complete: 53.1%; Average loss: 1.1662
Iteration: 2126; Percent complete: 53.1%; Average loss: 1.2217
Iteration: 2127; Percent complete: 53.2%; Average loss: 1.4271
Iteration: 2128; Percent complete: 53.2%; Average loss: 1.2625
Iteration: 2129; Percent complete: 53.2%; Average loss: 1.0859
Iteration: 2130; Percent complete: 53.2%; Average loss: 1.5100
Iteration: 2131; Percent complete: 53.3%; Average loss: 1.2915
Iteration: 2132; Percent complete: 53.3%; Average loss: 1.4736
Iteration: 2133; Percent complete: 53.3%; Average loss:

Iteration: 2249; Percent complete: 56.2%; Average loss: 1.2512
Iteration: 2250; Percent complete: 56.2%; Average loss: 1.2724
Iteration: 2251; Percent complete: 56.3%; Average loss: 1.2259
Iteration: 2252; Percent complete: 56.3%; Average loss: 1.0613
Iteration: 2253; Percent complete: 56.3%; Average loss: 1.4200
Iteration: 2254; Percent complete: 56.4%; Average loss: 1.2159
Iteration: 2255; Percent complete: 56.4%; Average loss: 1.2390
Iteration: 2256; Percent complete: 56.4%; Average loss: 1.0443
Iteration: 2257; Percent complete: 56.4%; Average loss: 1.2048
Iteration: 2258; Percent complete: 56.5%; Average loss: 1.0417
Iteration: 2259; Percent complete: 56.5%; Average loss: 1.2463
Iteration: 2260; Percent complete: 56.5%; Average loss: 1.1881
Iteration: 2261; Percent complete: 56.5%; Average loss: 1.0419
Iteration: 2262; Percent complete: 56.5%; Average loss: 1.0750
Iteration: 2263; Percent complete: 56.6%; Average loss: 1.1612
Iteration: 2264; Percent complete: 56.6%; Average loss:

Iteration: 2380; Percent complete: 59.5%; Average loss: 1.2265
Iteration: 2381; Percent complete: 59.5%; Average loss: 1.0423
Iteration: 2382; Percent complete: 59.6%; Average loss: 1.1242
Iteration: 2383; Percent complete: 59.6%; Average loss: 1.0595
Iteration: 2384; Percent complete: 59.6%; Average loss: 0.9526
Iteration: 2385; Percent complete: 59.6%; Average loss: 1.0856
Iteration: 2386; Percent complete: 59.7%; Average loss: 1.1541
Iteration: 2387; Percent complete: 59.7%; Average loss: 1.0866
Iteration: 2388; Percent complete: 59.7%; Average loss: 1.1584
Iteration: 2389; Percent complete: 59.7%; Average loss: 1.0779
Iteration: 2390; Percent complete: 59.8%; Average loss: 1.1540
Iteration: 2391; Percent complete: 59.8%; Average loss: 1.3038
Iteration: 2392; Percent complete: 59.8%; Average loss: 1.1259
Iteration: 2393; Percent complete: 59.8%; Average loss: 1.0974
Iteration: 2394; Percent complete: 59.9%; Average loss: 1.1156
Iteration: 2395; Percent complete: 59.9%; Average loss:

Iteration: 2511; Percent complete: 62.8%; Average loss: 0.9547
Iteration: 2512; Percent complete: 62.8%; Average loss: 0.9921
Iteration: 2513; Percent complete: 62.8%; Average loss: 0.9892
Iteration: 2514; Percent complete: 62.8%; Average loss: 0.7662
Iteration: 2515; Percent complete: 62.9%; Average loss: 0.9662
Iteration: 2516; Percent complete: 62.9%; Average loss: 0.9600
Iteration: 2517; Percent complete: 62.9%; Average loss: 0.9193
Iteration: 2518; Percent complete: 62.9%; Average loss: 0.9321
Iteration: 2519; Percent complete: 63.0%; Average loss: 1.0201
Iteration: 2520; Percent complete: 63.0%; Average loss: 1.0229
Iteration: 2521; Percent complete: 63.0%; Average loss: 0.9574
Iteration: 2522; Percent complete: 63.0%; Average loss: 1.0635
Iteration: 2523; Percent complete: 63.1%; Average loss: 1.0581
Iteration: 2524; Percent complete: 63.1%; Average loss: 0.9542
Iteration: 2525; Percent complete: 63.1%; Average loss: 0.9409
Iteration: 2526; Percent complete: 63.1%; Average loss:

Iteration: 2642; Percent complete: 66.0%; Average loss: 1.1079
Iteration: 2643; Percent complete: 66.1%; Average loss: 0.7985
Iteration: 2644; Percent complete: 66.1%; Average loss: 0.7765
Iteration: 2645; Percent complete: 66.1%; Average loss: 0.8350
Iteration: 2646; Percent complete: 66.1%; Average loss: 0.8046
Iteration: 2647; Percent complete: 66.2%; Average loss: 0.8543
Iteration: 2648; Percent complete: 66.2%; Average loss: 0.7960
Iteration: 2649; Percent complete: 66.2%; Average loss: 1.0744
Iteration: 2650; Percent complete: 66.2%; Average loss: 0.9524
Iteration: 2651; Percent complete: 66.3%; Average loss: 0.9472
Iteration: 2652; Percent complete: 66.3%; Average loss: 1.0276
Iteration: 2653; Percent complete: 66.3%; Average loss: 0.9839
Iteration: 2654; Percent complete: 66.3%; Average loss: 0.8441
Iteration: 2655; Percent complete: 66.4%; Average loss: 0.8844
Iteration: 2656; Percent complete: 66.4%; Average loss: 0.9119
Iteration: 2657; Percent complete: 66.4%; Average loss:

Iteration: 2773; Percent complete: 69.3%; Average loss: 0.9179
Iteration: 2774; Percent complete: 69.3%; Average loss: 0.8888
Iteration: 2775; Percent complete: 69.4%; Average loss: 0.8861
Iteration: 2776; Percent complete: 69.4%; Average loss: 0.9612
Iteration: 2777; Percent complete: 69.4%; Average loss: 0.9877
Iteration: 2778; Percent complete: 69.5%; Average loss: 0.7649
Iteration: 2779; Percent complete: 69.5%; Average loss: 0.7231
Iteration: 2780; Percent complete: 69.5%; Average loss: 0.9771
Iteration: 2781; Percent complete: 69.5%; Average loss: 0.8380
Iteration: 2782; Percent complete: 69.5%; Average loss: 0.8302
Iteration: 2783; Percent complete: 69.6%; Average loss: 0.7670
Iteration: 2784; Percent complete: 69.6%; Average loss: 0.9503
Iteration: 2785; Percent complete: 69.6%; Average loss: 0.8360
Iteration: 2786; Percent complete: 69.7%; Average loss: 0.7969
Iteration: 2787; Percent complete: 69.7%; Average loss: 0.8042
Iteration: 2788; Percent complete: 69.7%; Average loss:

Iteration: 2904; Percent complete: 72.6%; Average loss: 0.6858
Iteration: 2905; Percent complete: 72.6%; Average loss: 0.6544
Iteration: 2906; Percent complete: 72.7%; Average loss: 0.8524
Iteration: 2907; Percent complete: 72.7%; Average loss: 0.7485
Iteration: 2908; Percent complete: 72.7%; Average loss: 0.8842
Iteration: 2909; Percent complete: 72.7%; Average loss: 0.7427
Iteration: 2910; Percent complete: 72.8%; Average loss: 0.7884
Iteration: 2911; Percent complete: 72.8%; Average loss: 0.8604
Iteration: 2912; Percent complete: 72.8%; Average loss: 0.7250
Iteration: 2913; Percent complete: 72.8%; Average loss: 0.8450
Iteration: 2914; Percent complete: 72.9%; Average loss: 0.8216
Iteration: 2915; Percent complete: 72.9%; Average loss: 0.6278
Iteration: 2916; Percent complete: 72.9%; Average loss: 0.6958
Iteration: 2917; Percent complete: 72.9%; Average loss: 0.8005
Iteration: 2918; Percent complete: 73.0%; Average loss: 0.7463
Iteration: 2919; Percent complete: 73.0%; Average loss:

Iteration: 3035; Percent complete: 75.9%; Average loss: 0.7206
Iteration: 3036; Percent complete: 75.9%; Average loss: 0.7325
Iteration: 3037; Percent complete: 75.9%; Average loss: 0.7470
Iteration: 3038; Percent complete: 75.9%; Average loss: 0.7637
Iteration: 3039; Percent complete: 76.0%; Average loss: 0.7519
Iteration: 3040; Percent complete: 76.0%; Average loss: 0.7788
Iteration: 3041; Percent complete: 76.0%; Average loss: 0.6987
Iteration: 3042; Percent complete: 76.0%; Average loss: 0.6006
Iteration: 3043; Percent complete: 76.1%; Average loss: 0.9154
Iteration: 3044; Percent complete: 76.1%; Average loss: 0.6190
Iteration: 3045; Percent complete: 76.1%; Average loss: 0.8506
Iteration: 3046; Percent complete: 76.1%; Average loss: 0.5848
Iteration: 3047; Percent complete: 76.2%; Average loss: 0.7446
Iteration: 3048; Percent complete: 76.2%; Average loss: 0.6242
Iteration: 3049; Percent complete: 76.2%; Average loss: 0.7130
Iteration: 3050; Percent complete: 76.2%; Average loss:

Iteration: 3166; Percent complete: 79.1%; Average loss: 0.6700
Iteration: 3167; Percent complete: 79.2%; Average loss: 0.6005
Iteration: 3168; Percent complete: 79.2%; Average loss: 0.5865
Iteration: 3169; Percent complete: 79.2%; Average loss: 0.6394
Iteration: 3170; Percent complete: 79.2%; Average loss: 0.7157
Iteration: 3171; Percent complete: 79.3%; Average loss: 0.6618
Iteration: 3172; Percent complete: 79.3%; Average loss: 0.6837
Iteration: 3173; Percent complete: 79.3%; Average loss: 0.6762
Iteration: 3174; Percent complete: 79.3%; Average loss: 0.5861
Iteration: 3175; Percent complete: 79.4%; Average loss: 0.6051
Iteration: 3176; Percent complete: 79.4%; Average loss: 0.6210
Iteration: 3177; Percent complete: 79.4%; Average loss: 0.5841
Iteration: 3178; Percent complete: 79.5%; Average loss: 0.5504
Iteration: 3179; Percent complete: 79.5%; Average loss: 0.6741
Iteration: 3180; Percent complete: 79.5%; Average loss: 0.6466
Iteration: 3181; Percent complete: 79.5%; Average loss:

Iteration: 3297; Percent complete: 82.4%; Average loss: 0.5630
Iteration: 3298; Percent complete: 82.5%; Average loss: 0.5800
Iteration: 3299; Percent complete: 82.5%; Average loss: 0.7129
Iteration: 3300; Percent complete: 82.5%; Average loss: 0.5704
Iteration: 3301; Percent complete: 82.5%; Average loss: 0.6922
Iteration: 3302; Percent complete: 82.5%; Average loss: 0.5950
Iteration: 3303; Percent complete: 82.6%; Average loss: 0.5629
Iteration: 3304; Percent complete: 82.6%; Average loss: 0.6330
Iteration: 3305; Percent complete: 82.6%; Average loss: 0.5639
Iteration: 3306; Percent complete: 82.7%; Average loss: 0.5901
Iteration: 3307; Percent complete: 82.7%; Average loss: 0.6652
Iteration: 3308; Percent complete: 82.7%; Average loss: 0.6445
Iteration: 3309; Percent complete: 82.7%; Average loss: 0.6266
Iteration: 3310; Percent complete: 82.8%; Average loss: 0.5424
Iteration: 3311; Percent complete: 82.8%; Average loss: 0.5496
Iteration: 3312; Percent complete: 82.8%; Average loss:

Iteration: 3428; Percent complete: 85.7%; Average loss: 0.6715
Iteration: 3429; Percent complete: 85.7%; Average loss: 0.5559
Iteration: 3430; Percent complete: 85.8%; Average loss: 0.5792
Iteration: 3431; Percent complete: 85.8%; Average loss: 0.5876
Iteration: 3432; Percent complete: 85.8%; Average loss: 0.5862
Iteration: 3433; Percent complete: 85.8%; Average loss: 0.5787
Iteration: 3434; Percent complete: 85.9%; Average loss: 0.5499
Iteration: 3435; Percent complete: 85.9%; Average loss: 0.4893
Iteration: 3436; Percent complete: 85.9%; Average loss: 0.5403
Iteration: 3437; Percent complete: 85.9%; Average loss: 0.5273
Iteration: 3438; Percent complete: 86.0%; Average loss: 0.4953
Iteration: 3439; Percent complete: 86.0%; Average loss: 0.4228
Iteration: 3440; Percent complete: 86.0%; Average loss: 0.4604
Iteration: 3441; Percent complete: 86.0%; Average loss: 0.5484
Iteration: 3442; Percent complete: 86.1%; Average loss: 0.5386
Iteration: 3443; Percent complete: 86.1%; Average loss:

Iteration: 3559; Percent complete: 89.0%; Average loss: 0.5500
Iteration: 3560; Percent complete: 89.0%; Average loss: 0.4897
Iteration: 3561; Percent complete: 89.0%; Average loss: 0.5389
Iteration: 3562; Percent complete: 89.0%; Average loss: 0.4685
Iteration: 3563; Percent complete: 89.1%; Average loss: 0.5167
Iteration: 3564; Percent complete: 89.1%; Average loss: 0.4898
Iteration: 3565; Percent complete: 89.1%; Average loss: 0.4922
Iteration: 3566; Percent complete: 89.1%; Average loss: 0.4650
Iteration: 3567; Percent complete: 89.2%; Average loss: 0.5236
Iteration: 3568; Percent complete: 89.2%; Average loss: 0.4674
Iteration: 3569; Percent complete: 89.2%; Average loss: 0.5080
Iteration: 3570; Percent complete: 89.2%; Average loss: 0.4526
Iteration: 3571; Percent complete: 89.3%; Average loss: 0.5186
Iteration: 3572; Percent complete: 89.3%; Average loss: 0.5555
Iteration: 3573; Percent complete: 89.3%; Average loss: 0.4766
Iteration: 3574; Percent complete: 89.3%; Average loss:

Iteration: 3690; Percent complete: 92.2%; Average loss: 0.5038
Iteration: 3691; Percent complete: 92.3%; Average loss: 0.4489
Iteration: 3692; Percent complete: 92.3%; Average loss: 0.4205
Iteration: 3693; Percent complete: 92.3%; Average loss: 0.4211
Iteration: 3694; Percent complete: 92.3%; Average loss: 0.4327
Iteration: 3695; Percent complete: 92.4%; Average loss: 0.5130
Iteration: 3696; Percent complete: 92.4%; Average loss: 0.5296
Iteration: 3697; Percent complete: 92.4%; Average loss: 0.4620
Iteration: 3698; Percent complete: 92.5%; Average loss: 0.4310
Iteration: 3699; Percent complete: 92.5%; Average loss: 0.5749
Iteration: 3700; Percent complete: 92.5%; Average loss: 0.4525
Iteration: 3701; Percent complete: 92.5%; Average loss: 0.4436
Iteration: 3702; Percent complete: 92.5%; Average loss: 0.4670
Iteration: 3703; Percent complete: 92.6%; Average loss: 0.4860
Iteration: 3704; Percent complete: 92.6%; Average loss: 0.4263
Iteration: 3705; Percent complete: 92.6%; Average loss:

Iteration: 3821; Percent complete: 95.5%; Average loss: 0.4678
Iteration: 3822; Percent complete: 95.5%; Average loss: 0.4091
Iteration: 3823; Percent complete: 95.6%; Average loss: 0.3880
Iteration: 3824; Percent complete: 95.6%; Average loss: 0.4132
Iteration: 3825; Percent complete: 95.6%; Average loss: 0.3477
Iteration: 3826; Percent complete: 95.7%; Average loss: 0.4702
Iteration: 3827; Percent complete: 95.7%; Average loss: 0.3428
Iteration: 3828; Percent complete: 95.7%; Average loss: 0.3521
Iteration: 3829; Percent complete: 95.7%; Average loss: 0.4181
Iteration: 3830; Percent complete: 95.8%; Average loss: 0.4611
Iteration: 3831; Percent complete: 95.8%; Average loss: 0.5392
Iteration: 3832; Percent complete: 95.8%; Average loss: 0.4287
Iteration: 3833; Percent complete: 95.8%; Average loss: 0.4262
Iteration: 3834; Percent complete: 95.9%; Average loss: 0.4503
Iteration: 3835; Percent complete: 95.9%; Average loss: 0.3854
Iteration: 3836; Percent complete: 95.9%; Average loss:

Iteration: 3952; Percent complete: 98.8%; Average loss: 0.5043
Iteration: 3953; Percent complete: 98.8%; Average loss: 0.3444
Iteration: 3954; Percent complete: 98.9%; Average loss: 0.3689
Iteration: 3955; Percent complete: 98.9%; Average loss: 0.3674
Iteration: 3956; Percent complete: 98.9%; Average loss: 0.4126
Iteration: 3957; Percent complete: 98.9%; Average loss: 0.3708
Iteration: 3958; Percent complete: 99.0%; Average loss: 0.3055
Iteration: 3959; Percent complete: 99.0%; Average loss: 0.3872
Iteration: 3960; Percent complete: 99.0%; Average loss: 0.4583
Iteration: 3961; Percent complete: 99.0%; Average loss: 0.4019
Iteration: 3962; Percent complete: 99.1%; Average loss: 0.4040
Iteration: 3963; Percent complete: 99.1%; Average loss: 0.4487
Iteration: 3964; Percent complete: 99.1%; Average loss: 0.4743
Iteration: 3965; Percent complete: 99.1%; Average loss: 0.4025
Iteration: 3966; Percent complete: 99.2%; Average loss: 0.3652
Iteration: 3967; Percent complete: 99.2%; Average loss:

## Run and Play

In [None]:
# Put both networks into evaluation mode (affects dropout layers)
for _module in (encoder, decoder):
    _module.eval()

# Wrap the encoder/decoder pair in the greedy-search decoding module
searcher = GreedySearchDecoder(encoder, decoder)

# Launch the chat session; input is read at the "> " prompt
evaluateInput(encoder, decoder, searcher, voc)

> hello
Bot: hello is that miss smith ? ?
> no it is not
Bot: how long is the bus ride ? ?
> i do not take the bus
Bot: you d better take it out . .
> is it for ecology ?
Error: Encountered unknown word.
> how are you ?
Bot: i m great . thanks . how .
> fine
Bot: that s too bad . you ?
> do you love cats ?
Bot: i think so . . .
> do you like puppys ?
Error: Encountered unknown word.
> what do you eat for breakfast ?
Bot: i want to buy some pie . .
> do you love a woman ?
Bot: i think i love it . .
> hello
Bot: hello is that miss smith ? ?
> i like smoking
Bot: you d better take a cup of coffee .
> coffe is drug
Error: Encountered unknown word.
> coffee isdrug
Error: Encountered unknown word.
> coffee is drug
Error: Encountered unknown word.
> coffeeis bad
Error: Encountered unknown word.
> coffee is bad
Bot: you re right . in the morning .
