In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import torch
from torch.jit import script, trace
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import csv
import random
import re
import os
import unicodedata
import codecs
from io import open
import itertools
import math


USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")

In [2]:
corpus_name = "cornell movie-dialogs corpus"
corpus = os.path.join("data", corpus_name)

def printLines(file, n=10):
    with open(file, 'rb') as datafile:
        lines = datafile.readlines()
    for line in lines[:n]:
        print(line)

printLines(os.path.join(corpus, "movie_lines.txt"))

b'L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!\r\n'
b'L1044 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ They do to!\r\n'
b'L985 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I hope so.\r\n'
b'L984 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ She okay?\r\n'
b"L925 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Let's go.\r\n"
b'L924 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ Wow\r\n'
b"L872 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Okay -- you're gonna need to learn how to lie.\r\n"
b'L871 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ No\r\n'
b'L870 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I\'m kidding.  You know how sometimes you just become this "persona"?  And you don\'t know how to quit?\r\n'
b'L869 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Like my fear of wearing pastels?\r\n'


In [3]:
# Splits each line of the file into a dictionary of fields
def loadLines(fileName, fields):
    lines = {}
    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(" +++$+++ ")
            # Extract fields
            lineObj = {}
            for i, field in enumerate(fields):
                lineObj[field] = values[i]
            lines[lineObj['lineID']] = lineObj
    return lines


# Groups fields of lines from `loadLines` into conversations based on *movie_conversations.txt*
def loadConversations(fileName, lines, fields):
    conversations = []
    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(" +++$+++ ")
            # Extract fields
            convObj = {}
            for i, field in enumerate(fields):
                convObj[field] = values[i]
            # Convert string to list (convObj["utteranceIDs"] == "['L598485', 'L598486', ...]")
            lineIds = eval(convObj["utteranceIDs"])
            # Reassemble lines
            convObj["lines"] = []
            for lineId in lineIds:
                convObj["lines"].append(lines[lineId])
            conversations.append(convObj)
    return conversations


# Extracts pairs of sentences from conversations
def extractSentencePairs(conversations):
    qa_pairs = []
    for conversation in conversations:
        # Iterate over all the lines of the conversation
        for i in range(len(conversation["lines"]) - 1):  # We ignore the last line (no answer for it)
            inputLine = conversation["lines"][i]["text"].strip()
            targetLine = conversation["lines"][i+1]["text"].strip()
            # Filter wrong samples (if one of the lists is empty)
            if inputLine and targetLine:
                qa_pairs.append([inputLine, targetLine])
    return qa_pairs

In [4]:
# Define path to new file
datafile = os.path.join(corpus, "formatted_movie_lines.txt")

delimiter = '\t'
# Unescape the delimiter
delimiter = str(codecs.decode(delimiter, "unicode_escape"))

# Initialize lines dict, conversations list, and field ids
lines = {}
conversations = []
MOVIE_LINES_FIELDS = ["lineID", "characterID", "movieID", "character", "text"]
MOVIE_CONVERSATIONS_FIELDS = ["character1ID", "character2ID", "movieID", "utteranceIDs"]

# Load lines and process conversations
print("\nProcessing corpus...")
lines = loadLines(os.path.join(corpus, "movie_lines.txt"), MOVIE_LINES_FIELDS)
print("\nLoading conversations...")
conversations = loadConversations(os.path.join(corpus, "movie_conversations.txt"),
                                  lines, MOVIE_CONVERSATIONS_FIELDS)

# Write new csv file
print("\nWriting newly formatted file...")
with open(datafile, 'w', encoding='utf-8') as outputfile:
    writer = csv.writer(outputfile, delimiter=delimiter, lineterminator='\n')
    for pair in extractSentencePairs(conversations):
        writer.writerow(pair)

# Print a sample of lines
print("\nSample lines from file:")
printLines(datafile)


Processing corpus...


IndexError: list index out of range

In [5]:
# Default word tokens
PAD_token = 0  # Used for padding short sentences
SOS_token = 1  # Start-of-sentence token
EOS_token = 2  # End-of-sentence token

class Voc:
    def __init__(self, name):
        self.name = name
        self.trimmed = False
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3  # Count SOS, EOS, PAD

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1
        else:
            self.word2count[word] += 1

    # Remove words below a certain count threshold
    def trim(self, min_count):
        if self.trimmed:
            return
        self.trimmed = True

        keep_words = []

        for k, v in self.word2count.items():
            if v >= min_count:
                keep_words.append(k)

        print('keep_words {} / {} = {:.4f}'.format(
            len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index)
        ))

        # Reinitialize dictionaries
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3 # Count default tokens

        for word in keep_words:
            self.addWord(word)

In [6]:
MAX_LENGTH = 10  # Maximum sentence length to consider

# Turn a Unicode string to plain ASCII, thanks to
# https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

# Lowercase, trim, and remove non-letter characters
def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    s = re.sub(r"\s+", r" ", s).strip()
    return s

# Read query/response pairs and return a voc object
def readVocs(datafile, corpus_name):
    print("Reading lines...")
    # Read the file and split into lines
    lines = open(datafile, encoding='utf-8').\
        read().strip().split('\n')
    # Split every line into pairs and normalize
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]
    voc = Voc(corpus_name)
    return voc, pairs

# Returns True iff both sentences in a pair 'p' are under the MAX_LENGTH threshold
def filterPair(p):
    # Input sequences need to preserve the last word for EOS token
    return len(p[0].split(' ')) < MAX_LENGTH and len(p[1].split(' ')) < MAX_LENGTH

# Filter pairs using filterPair condition
def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]

# Using the functions defined above, return a populated voc object and pairs list
def loadPrepareData(corpus, corpus_name, datafile, save_dir):
    print("Start preparing training data ...")
    voc, pairs = readVocs(datafile, corpus_name)
    print("Read {!s} sentence pairs".format(len(pairs)))
    pairs = filterPairs(pairs)
    print("Trimmed to {!s} sentence pairs".format(len(pairs)))
    print("Counting words...")
    for pair in pairs:
        voc.addSentence(pair[0])
        voc.addSentence(pair[1])
    print("Counted words:", voc.num_words)
    return voc, pairs


# Load/Assemble voc and pairs
save_dir = os.path.join("data", "save")
voc, pairs = loadPrepareData(corpus, corpus_name, datafile, save_dir)
# Print some pairs to validate
print("\npairs:")
for pair in pairs[:10]:
    print(pair)

Start preparing training data ...
Reading lines...
Read 221282 sentence pairs
Trimmed to 64271 sentence pairs
Counting words...
Counted words: 18008

pairs:
['there .', 'where ?']
['you have my word . as a gentleman', 'you re sweet .']
['hi .', 'looks like things worked out tonight huh ?']
['you know chastity ?', 'i believe we share an art instructor']
['have fun tonight ?', 'tons']
['well no . . .', 'then that s all you had to say .']
['then that s all you had to say .', 'but']
['but', 'you always been this selfish ?']
['do you listen to this crap ?', 'what crap ?']
['what good stuff ?', 'the real you .']


In [7]:
MIN_COUNT = 3    # Minimum word count threshold for trimming

def trimRareWords(voc, pairs, MIN_COUNT):
    # Trim words used under the MIN_COUNT from the voc
    voc.trim(MIN_COUNT)
    # Filter out pairs with trimmed words
    keep_pairs = []
    for pair in pairs:
        input_sentence = pair[0]
        output_sentence = pair[1]
        keep_input = True
        keep_output = True
        # Check input sentence
        for word in input_sentence.split(' '):
            if word not in voc.word2index:
                keep_input = False
                break
        # Check output sentence
        for word in output_sentence.split(' '):
            if word not in voc.word2index:
                keep_output = False
                break

        # Only keep pairs that do not contain trimmed word(s) in their input or output sentence
        if keep_input and keep_output:
            keep_pairs.append(pair)

    print("Trimmed from {} pairs to {}, {:.4f} of total".format(len(pairs), len(keep_pairs), len(keep_pairs) / len(pairs)))
    return keep_pairs


# Trim voc and pairs
pairs = trimRareWords(voc, pairs, MIN_COUNT)

keep_words 7823 / 18005 = 0.4345
Trimmed from 64271 pairs to 53165, 0.8272 of total


In [21]:
def indexesFromSentence(voc, sentence):
    return [voc.word2index[word] for word in sentence.split(' ')] + [EOS_token]


def zeroPadding(l, fillvalue=PAD_token):
    return list(itertools.zip_longest(*l, fillvalue=fillvalue))

def binaryMatrix(l, value=PAD_token):
    m = []
    for i, seq in enumerate(l):
        m.append([])
        for token in seq:
            if token == PAD_token:
                m[i].append(0)
            else:
                m[i].append(1)
    return m

# Returns padded input sequence tensor and lengths
def inputVar(l, voc):
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    padVar = torch.LongTensor(padList)
    return padVar, lengths

# Returns padded target sequence tensor, padding mask, and max target length
def outputVar(l, voc):
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    max_target_len = max([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    mask = binaryMatrix(padList)
    mask = torch.ByteTensor(mask)
    padVar = torch.LongTensor(padList)
    return padVar, mask, max_target_len

# Returns all items for a given batch of pairs
def batch2TrainData(voc, pair_batch):
    pair_batch.sort(key=lambda x: len(x[0].split(" ")), reverse=True)
    input_batch, output_batch = [], []
    for pair in pair_batch:
        input_batch.append(pair[0])
        output_batch.append(pair[1])
    inp, lengths = inputVar(input_batch, voc)
    output, mask, max_target_len = outputVar(output_batch, voc)
    return inp, lengths, output, mask, max_target_len


# Example for validation
small_batch_size = 5
batches = batch2TrainData(voc, [random.choice(pairs) for _ in range(small_batch_size)])
input_variable, lengths, target_variable, mask, max_target_len = batches

print(batches)
print("input_variable:", input_variable)
print("lengths:", lengths)
print("target_variable:", target_variable)
print("mask:", mask)
print("max_target_len:", max_target_len)

(tensor([[ 450, 6864,  100,  239,    5],
        [ 380,    4,   47,   82,    7],
        [  12,   42,  427,  479,  467],
        [ 264,   77, 2784,   40,    6],
        [ 111,   92,  909,   45,    2],
        [  12,    7,    6, 2222,    0],
        [4331,    6,    2,    2,    0],
        [   4,    2,    0,    0,    0],
        [   2,    0,    0,    0,    0]]), tensor([9, 8, 7, 7, 5]), tensor([[  35,   25, 2411,  318,  167],
        [   5,  200,    4,  318,    4],
        [  37,   70,    2,    2,    2],
        [  53,  240,    0,    0,    0],
        [4330,    4,    0,    0,    0],
        [   4,    2,    0,    0,    0],
        [   2,    0,    0,    0,    0]]), tensor([[1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 0, 0, 0],
        [1, 1, 0, 0, 0],
        [1, 1, 0, 0, 0],
        [1, 0, 0, 0, 0]], dtype=torch.uint8), 7)
input_variable: tensor([[ 450, 6864,  100,  239,    5],
        [ 380,    4,   47,   82,    7],
        [  12,   42,  427,  479,  4

In [9]:
class EncoderRNN(nn.Module):
    def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
        super(EncoderRNN, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embedding = embedding

        # Initialize GRU; the input_size and hidden_size params are both set to 'hidden_size'
        #   because our input size is a word embedding with number of features == hidden_size
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers,
                          dropout=(0 if n_layers == 1 else dropout), bidirectional=True)

    def forward(self, input_seq, input_lengths, hidden=None):
        # Convert word indexes to embeddings
        embedded = self.embedding(input_seq)
        # Pack padded batch of sequences for RNN module
        packed = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
        # Forward pass through GRU
        outputs, hidden = self.gru(packed, hidden)
        # Unpack padding
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs)
        # Sum bidirectional GRU outputs
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, : ,self.hidden_size:]
        # Return output and final hidden state
        return outputs, hidden

In [10]:
# Luong attention layer
class Attn(nn.Module):
    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()
        self.method = method
        if self.method not in ['dot', 'general', 'concat']:
            raise ValueError(self.method, "is not an appropriate attention method.")
        self.hidden_size = hidden_size
        if self.method == 'general':
            self.attn = nn.Linear(self.hidden_size, hidden_size)
        elif self.method == 'concat':
            self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
            self.v = nn.Parameter(torch.FloatTensor(hidden_size))

    def dot_score(self, hidden, encoder_output):
        return torch.sum(hidden * encoder_output, dim=2)

    def general_score(self, hidden, encoder_output):
        energy = self.attn(encoder_output)
        return torch.sum(hidden * energy, dim=2)

    def concat_score(self, hidden, encoder_output):
        energy = self.attn(torch.cat((hidden.expand(encoder_output.size(0), -1, -1), encoder_output), 2)).tanh()
        return torch.sum(self.v * energy, dim=2)

    def forward(self, hidden, encoder_outputs):
        # Calculate the attention weights (energies) based on the given method
        if self.method == 'general':
            attn_energies = self.general_score(hidden, encoder_outputs)
        elif self.method == 'concat':
            attn_energies = self.concat_score(hidden, encoder_outputs)
        elif self.method == 'dot':
            attn_energies = self.dot_score(hidden, encoder_outputs)

        # Transpose max_length and batch_size dimensions
        attn_energies = attn_energies.t()

        # Return the softmax normalized probability scores (with added dimension)
        return F.softmax(attn_energies, dim=1).unsqueeze(1)

In [11]:
class LuongAttnDecoderRNN(nn.Module):
    def __init__(self, attn_model, embedding, hidden_size, output_size, n_layers=1, dropout=0.1):
        super(LuongAttnDecoderRNN, self).__init__()

        # Keep for reference
        self.attn_model = attn_model
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = dropout

        # Define layers
        self.embedding = embedding
        self.embedding_dropout = nn.Dropout(dropout)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=(0 if n_layers == 1 else dropout))
        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

        self.attn = Attn(attn_model, hidden_size)

    def forward(self, input_step, last_hidden, encoder_outputs):
        # Note: we run this one step (word) at a time
        # Get embedding of current input word
        embedded = self.embedding(input_step)
        embedded = self.embedding_dropout(embedded)
        # Forward through unidirectional GRU
        rnn_output, hidden = self.gru(embedded, last_hidden)
        # Calculate attention weights from the current GRU output
        attn_weights = self.attn(rnn_output, encoder_outputs)
        # Multiply attention weights to encoder outputs to get new "weighted sum" context vector
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1))
        # Concatenate weighted context vector and GRU output using Luong eq. 5
        rnn_output = rnn_output.squeeze(0)
        context = context.squeeze(1)
        concat_input = torch.cat((rnn_output, context), 1)
        concat_output = torch.tanh(self.concat(concat_input))
        # Predict next word using Luong eq. 6
        output = self.out(concat_output)
        output = F.softmax(output, dim=1)
        # Return output and final hidden state
        return output, hidden

In [12]:
def maskNLLLoss(inp, target, mask):
    nTotal = mask.sum()
    crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
    loss = crossEntropy.masked_select(mask).mean()
    loss = loss.to(device)
    return loss, nTotal.item()

In [13]:
def train(input_variable, lengths, target_variable, mask, max_target_len, encoder, decoder, embedding,
          encoder_optimizer, decoder_optimizer, batch_size, clip, max_length=MAX_LENGTH):

    # Zero gradients
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    # Set device options
    input_variable = input_variable.to(device)
    lengths = lengths.to(device)
    target_variable = target_variable.to(device)
    mask = mask.to(device)

    # Initialize variables
    loss = 0
    print_losses = []
    n_totals = 0

    # Forward pass through encoder
    encoder_outputs, encoder_hidden = encoder(input_variable, lengths)

    # Create initial decoder input (start with SOS tokens for each sentence)
    decoder_input = torch.LongTensor([[SOS_token for _ in range(batch_size)]])
    decoder_input = decoder_input.to(device)

    # Set initial decoder hidden state to the encoder's final hidden state
    decoder_hidden = encoder_hidden[:decoder.n_layers]

    # Determine if we are using teacher forcing this iteration
    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False

    # Forward batch of sequences through decoder one time step at a time
    if use_teacher_forcing:
        for t in range(max_target_len):
            decoder_output, decoder_hidden = decoder(
                decoder_input, decoder_hidden, encoder_outputs
            )
            # Teacher forcing: next input is current target
            decoder_input = target_variable[t].view(1, -1)
            # Calculate and accumulate loss
            mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
            loss += mask_loss
            print_losses.append(mask_loss.item() * nTotal)
            n_totals += nTotal
    else:
        for t in range(max_target_len):
            decoder_output, decoder_hidden = decoder(
                decoder_input, decoder_hidden, encoder_outputs
            )
            # No teacher forcing: next input is decoder's own current output
            _, topi = decoder_output.topk(1)
            decoder_input = torch.LongTensor([[topi[i][0] for i in range(batch_size)]])
            decoder_input = decoder_input.to(device)
            # Calculate and accumulate loss
            mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
            loss += mask_loss
            print_losses.append(mask_loss.item() * nTotal)
            n_totals += nTotal

    # Perform backpropatation
    loss.backward()

    # Clip gradients: gradients are modified in place
    _ = nn.utils.clip_grad_norm_(encoder.parameters(), clip)
    _ = nn.utils.clip_grad_norm_(decoder.parameters(), clip)

    # Adjust model weights
    encoder_optimizer.step()
    decoder_optimizer.step()

    return sum(print_losses) / n_totals

In [14]:
def trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer, embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size, print_every, save_every, clip, corpus_name, loadFilename):

    # Load batches for each iteration
    training_batches = [batch2TrainData(voc, [random.choice(pairs) for _ in range(batch_size)])
                      for _ in range(n_iteration)]

    # Initializations
    print('Initializing ...')
    start_iteration = 1
    print_loss = 0
    if loadFilename:
        start_iteration = checkpoint['iteration'] + 1

    # Training loop
    print("Training...")
    for iteration in range(start_iteration, n_iteration + 1):
        training_batch = training_batches[iteration - 1]
        # Extract fields from batch
        input_variable, lengths, target_variable, mask, max_target_len = training_batch

        # Run a training iteration with batch
        loss = train(input_variable, lengths, target_variable, mask, max_target_len, encoder,
                     decoder, embedding, encoder_optimizer, decoder_optimizer, batch_size, clip)
        print_loss += loss

        # Print progress
        if iteration % print_every == 0:
            print_loss_avg = print_loss / print_every
            print("Iteration: {}; Percent complete: {:.1f}%; Average loss: {:.4f}".format(iteration, iteration / n_iteration * 100, print_loss_avg))
            print_loss = 0

        # Save checkpoint
        if (iteration % save_every == 0):
            directory = os.path.join(save_dir, model_name, corpus_name, '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, hidden_size))
            if not os.path.exists(directory):
                os.makedirs(directory)
            torch.save({
                'iteration': iteration,
                'en': encoder.state_dict(),
                'de': decoder.state_dict(),
                'en_opt': encoder_optimizer.state_dict(),
                'de_opt': decoder_optimizer.state_dict(),
                'loss': loss,
                'voc_dict': voc.__dict__,
                'embedding': embedding.state_dict()
            }, os.path.join(directory, '{}_{}.tar'.format(iteration, 'checkpoint')))

In [15]:
class GreedySearchDecoder(nn.Module):
    def __init__(self, encoder, decoder):
        super(GreedySearchDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, input_seq, input_length, max_length):
        # Forward input through encoder model
        encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)
        # Prepare encoder's final hidden layer to be first hidden input to the decoder
        decoder_hidden = encoder_hidden[:decoder.n_layers]
        # Initialize decoder input with SOS_token
        decoder_input = torch.ones(1, 1, device=device, dtype=torch.long) * SOS_token
        # Initialize tensors to append decoded words to
        all_tokens = torch.zeros([0], device=device, dtype=torch.long)
        all_scores = torch.zeros([0], device=device)
        # Iteratively decode one word token at a time
        for _ in range(max_length):
            # Forward pass through decoder
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
            # Obtain most likely word token and its softmax score
            decoder_scores, decoder_input = torch.max(decoder_output, dim=1)
            # Record token and score
            all_tokens = torch.cat((all_tokens, decoder_input), dim=0)
            all_scores = torch.cat((all_scores, decoder_scores), dim=0)
            # Prepare current token to be next decoder input (add a dimension)
            decoder_input = torch.unsqueeze(decoder_input, 0)
        # Return collections of word tokens and scores
        return all_tokens, all_scores

In [16]:
def evaluate(encoder, decoder, searcher, voc, sentence, max_length=MAX_LENGTH):
    ### Format input sentence as a batch
    # words -> indexes
    indexes_batch = [indexesFromSentence(voc, sentence)]
    # Create lengths tensor
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    # Transpose dimensions of batch to match models' expectations
    input_batch = torch.LongTensor(indexes_batch).transpose(0, 1)
    # Use appropriate device
    input_batch = input_batch.to(device)
    lengths = lengths.to(device)
    # Decode sentence with searcher
    tokens, scores = searcher(input_batch, lengths, max_length)
    # indexes -> words
    decoded_words = [voc.index2word[token.item()] for token in tokens]
    return decoded_words


def evaluateInput(encoder, decoder, searcher, voc):
    input_sentence = ''
    while(1):
        try:
            # Get input sentence
            input_sentence = input('> ')
            # Check if it is quit case
            if input_sentence == 'q' or input_sentence == 'quit': break
            # Normalize sentence
            input_sentence = normalizeString(input_sentence)
            # Evaluate sentence
            output_words = evaluate(encoder, decoder, searcher, voc, input_sentence)
            # Format and print response sentence
            output_words[:] = [x for x in output_words if not (x == 'EOS' or x == 'PAD')]
            print('Bot:', ' '.join(output_words))

        except KeyError:
            print("Error: Encountered unknown word.")

In [17]:
# Configure models
model_name = 'cb_model'
attn_model = 'dot'
#attn_model = 'general'
#attn_model = 'concat'
hidden_size = 500
encoder_n_layers = 2
decoder_n_layers = 2
dropout = 0.1
batch_size = 64

# Set checkpoint to load from; set to None if starting from scratch
loadFilename = None
checkpoint_iter = 4000
#loadFilename = os.path.join(save_dir, model_name, corpus_name,
#                            '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, hidden_size),
#                            '{}_checkpoint.tar'.format(checkpoint_iter))


# Load model if a loadFilename is provided
if loadFilename:
    # If loading on same machine the model was trained on
    checkpoint = torch.load(loadFilename)
    # If loading a model trained on GPU to CPU
    #checkpoint = torch.load(loadFilename, map_location=torch.device('cpu'))
    encoder_sd = checkpoint['en']
    decoder_sd = checkpoint['de']
    encoder_optimizer_sd = checkpoint['en_opt']
    decoder_optimizer_sd = checkpoint['de_opt']
    embedding_sd = checkpoint['embedding']
    voc.__dict__ = checkpoint['voc_dict']


print('Building encoder and decoder ...')
# Initialize word embeddings
embedding = nn.Embedding(voc.num_words, hidden_size)
if loadFilename:
    embedding.load_state_dict(embedding_sd)
# Initialize encoder & decoder models
encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.num_words, decoder_n_layers, dropout)
if loadFilename:
    encoder.load_state_dict(encoder_sd)
    decoder.load_state_dict(decoder_sd)
# Use appropriate device
encoder = encoder.to(device)
decoder = decoder.to(device)
print('Models built and ready to go!')

Building encoder and decoder ...
Models built and ready to go!


In [18]:
# Configure training/optimization
clip = 50.0
teacher_forcing_ratio = 1.0
learning_rate = 0.0001
decoder_learning_ratio = 5.0
n_iteration = 4000
print_every = 1
save_every = 500

# Ensure dropout layers are in train mode
encoder.train()
decoder.train()

# Initialize optimizers
print('Building optimizers ...')
encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate * decoder_learning_ratio)
if loadFilename:
    encoder_optimizer.load_state_dict(encoder_optimizer_sd)
    decoder_optimizer.load_state_dict(decoder_optimizer_sd)

# Run training iterations
print("Starting Training!")
trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer,
           embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size,
           print_every, save_every, clip, corpus_name, loadFilename)

Building optimizers ...
Starting Training!
Initializing ...
Training...
Iteration: 1; Percent complete: 0.0%; Average loss: 8.9656
Iteration: 2; Percent complete: 0.1%; Average loss: 8.8472
Iteration: 3; Percent complete: 0.1%; Average loss: 8.6749
Iteration: 4; Percent complete: 0.1%; Average loss: 8.3608
Iteration: 5; Percent complete: 0.1%; Average loss: 8.0245
Iteration: 6; Percent complete: 0.1%; Average loss: 7.4705
Iteration: 7; Percent complete: 0.2%; Average loss: 7.0121
Iteration: 8; Percent complete: 0.2%; Average loss: 6.6287
Iteration: 9; Percent complete: 0.2%; Average loss: 6.5923
Iteration: 10; Percent complete: 0.2%; Average loss: 6.5143
Iteration: 11; Percent complete: 0.3%; Average loss: 5.9821
Iteration: 12; Percent complete: 0.3%; Average loss: 5.9566
Iteration: 13; Percent complete: 0.3%; Average loss: 5.5778
Iteration: 14; Percent complete: 0.4%; Average loss: 5.3683
Iteration: 15; Percent complete: 0.4%; Average loss: 5.3332
Iteration: 16; Percent complete: 0.4%

Iteration: 135; Percent complete: 3.4%; Average loss: 4.0409
Iteration: 136; Percent complete: 3.4%; Average loss: 4.5167
Iteration: 137; Percent complete: 3.4%; Average loss: 4.4213
Iteration: 138; Percent complete: 3.5%; Average loss: 4.2293
Iteration: 139; Percent complete: 3.5%; Average loss: 4.2384
Iteration: 140; Percent complete: 3.5%; Average loss: 4.1373
Iteration: 141; Percent complete: 3.5%; Average loss: 4.2013
Iteration: 142; Percent complete: 3.5%; Average loss: 4.2476
Iteration: 143; Percent complete: 3.6%; Average loss: 4.4549
Iteration: 144; Percent complete: 3.6%; Average loss: 4.2716
Iteration: 145; Percent complete: 3.6%; Average loss: 4.2958
Iteration: 146; Percent complete: 3.6%; Average loss: 4.3969
Iteration: 147; Percent complete: 3.7%; Average loss: 4.2432
Iteration: 148; Percent complete: 3.7%; Average loss: 4.0816
Iteration: 149; Percent complete: 3.7%; Average loss: 4.5366
Iteration: 150; Percent complete: 3.8%; Average loss: 4.2596
Iteration: 151; Percent 

Iteration: 269; Percent complete: 6.7%; Average loss: 4.1801
Iteration: 270; Percent complete: 6.8%; Average loss: 4.0716
Iteration: 271; Percent complete: 6.8%; Average loss: 3.8167
Iteration: 272; Percent complete: 6.8%; Average loss: 4.1609
Iteration: 273; Percent complete: 6.8%; Average loss: 3.8253
Iteration: 274; Percent complete: 6.9%; Average loss: 3.7227
Iteration: 275; Percent complete: 6.9%; Average loss: 4.1105
Iteration: 276; Percent complete: 6.9%; Average loss: 3.8909
Iteration: 277; Percent complete: 6.9%; Average loss: 3.8964
Iteration: 278; Percent complete: 7.0%; Average loss: 3.8820
Iteration: 279; Percent complete: 7.0%; Average loss: 4.0643
Iteration: 280; Percent complete: 7.0%; Average loss: 3.8955
Iteration: 281; Percent complete: 7.0%; Average loss: 4.0127
Iteration: 282; Percent complete: 7.0%; Average loss: 3.8870
Iteration: 283; Percent complete: 7.1%; Average loss: 4.0599
Iteration: 284; Percent complete: 7.1%; Average loss: 3.9663
Iteration: 285; Percent 

Iteration: 403; Percent complete: 10.1%; Average loss: 3.8639
Iteration: 404; Percent complete: 10.1%; Average loss: 3.7011
Iteration: 405; Percent complete: 10.1%; Average loss: 3.9755
Iteration: 406; Percent complete: 10.2%; Average loss: 3.9482
Iteration: 407; Percent complete: 10.2%; Average loss: 3.9292
Iteration: 408; Percent complete: 10.2%; Average loss: 3.9398
Iteration: 409; Percent complete: 10.2%; Average loss: 3.9840
Iteration: 410; Percent complete: 10.2%; Average loss: 3.6226
Iteration: 411; Percent complete: 10.3%; Average loss: 3.7152
Iteration: 412; Percent complete: 10.3%; Average loss: 3.8582
Iteration: 413; Percent complete: 10.3%; Average loss: 3.5674
Iteration: 414; Percent complete: 10.3%; Average loss: 4.0390
Iteration: 415; Percent complete: 10.4%; Average loss: 3.6703
Iteration: 416; Percent complete: 10.4%; Average loss: 4.0953
Iteration: 417; Percent complete: 10.4%; Average loss: 3.9137
Iteration: 418; Percent complete: 10.4%; Average loss: 3.7830
Iteratio

Iteration: 535; Percent complete: 13.4%; Average loss: 3.5976
Iteration: 536; Percent complete: 13.4%; Average loss: 3.6925
Iteration: 537; Percent complete: 13.4%; Average loss: 3.5138
Iteration: 538; Percent complete: 13.5%; Average loss: 3.5557
Iteration: 539; Percent complete: 13.5%; Average loss: 3.7101
Iteration: 540; Percent complete: 13.5%; Average loss: 3.3911
Iteration: 541; Percent complete: 13.5%; Average loss: 3.6075
Iteration: 542; Percent complete: 13.6%; Average loss: 3.5280
Iteration: 543; Percent complete: 13.6%; Average loss: 3.8809
Iteration: 544; Percent complete: 13.6%; Average loss: 3.8787
Iteration: 545; Percent complete: 13.6%; Average loss: 3.7954
Iteration: 546; Percent complete: 13.7%; Average loss: 3.6904
Iteration: 547; Percent complete: 13.7%; Average loss: 3.7764
Iteration: 548; Percent complete: 13.7%; Average loss: 3.6893
Iteration: 549; Percent complete: 13.7%; Average loss: 3.8016
Iteration: 550; Percent complete: 13.8%; Average loss: 3.7322
Iteratio

Iteration: 667; Percent complete: 16.7%; Average loss: 3.6494
Iteration: 668; Percent complete: 16.7%; Average loss: 3.7747
Iteration: 669; Percent complete: 16.7%; Average loss: 3.6486
Iteration: 670; Percent complete: 16.8%; Average loss: 3.6554
Iteration: 671; Percent complete: 16.8%; Average loss: 3.7997
Iteration: 672; Percent complete: 16.8%; Average loss: 3.5799
Iteration: 673; Percent complete: 16.8%; Average loss: 3.5915
Iteration: 674; Percent complete: 16.9%; Average loss: 3.5159
Iteration: 675; Percent complete: 16.9%; Average loss: 3.7581
Iteration: 676; Percent complete: 16.9%; Average loss: 3.6789
Iteration: 677; Percent complete: 16.9%; Average loss: 3.6467
Iteration: 678; Percent complete: 17.0%; Average loss: 3.4229
Iteration: 679; Percent complete: 17.0%; Average loss: 3.6357
Iteration: 680; Percent complete: 17.0%; Average loss: 3.8363
Iteration: 681; Percent complete: 17.0%; Average loss: 3.6293
Iteration: 682; Percent complete: 17.1%; Average loss: 3.4668
Iteratio

Iteration: 799; Percent complete: 20.0%; Average loss: 3.5569
Iteration: 800; Percent complete: 20.0%; Average loss: 3.5582
Iteration: 801; Percent complete: 20.0%; Average loss: 3.5631
Iteration: 802; Percent complete: 20.1%; Average loss: 3.3235
Iteration: 803; Percent complete: 20.1%; Average loss: 3.5801
Iteration: 804; Percent complete: 20.1%; Average loss: 3.8898
Iteration: 805; Percent complete: 20.1%; Average loss: 3.4954
Iteration: 806; Percent complete: 20.2%; Average loss: 3.5904
Iteration: 807; Percent complete: 20.2%; Average loss: 3.6570
Iteration: 808; Percent complete: 20.2%; Average loss: 3.7355
Iteration: 809; Percent complete: 20.2%; Average loss: 3.6604
Iteration: 810; Percent complete: 20.2%; Average loss: 3.5675
Iteration: 811; Percent complete: 20.3%; Average loss: 3.6138
Iteration: 812; Percent complete: 20.3%; Average loss: 3.7357
Iteration: 813; Percent complete: 20.3%; Average loss: 3.7497
Iteration: 814; Percent complete: 20.3%; Average loss: 3.6268
Iteratio

Iteration: 931; Percent complete: 23.3%; Average loss: 3.4464
Iteration: 932; Percent complete: 23.3%; Average loss: 3.5497
Iteration: 933; Percent complete: 23.3%; Average loss: 3.3374
Iteration: 934; Percent complete: 23.4%; Average loss: 3.8464
Iteration: 935; Percent complete: 23.4%; Average loss: 3.4751
Iteration: 936; Percent complete: 23.4%; Average loss: 3.2948
Iteration: 937; Percent complete: 23.4%; Average loss: 3.5278
Iteration: 938; Percent complete: 23.4%; Average loss: 3.3506
Iteration: 939; Percent complete: 23.5%; Average loss: 3.5820
Iteration: 940; Percent complete: 23.5%; Average loss: 3.6757
Iteration: 941; Percent complete: 23.5%; Average loss: 3.2880
Iteration: 942; Percent complete: 23.5%; Average loss: 3.4466
Iteration: 943; Percent complete: 23.6%; Average loss: 3.3738
Iteration: 944; Percent complete: 23.6%; Average loss: 3.4551
Iteration: 945; Percent complete: 23.6%; Average loss: 3.3840
Iteration: 946; Percent complete: 23.6%; Average loss: 3.4876
Iteratio

Iteration: 1062; Percent complete: 26.6%; Average loss: 3.4828
Iteration: 1063; Percent complete: 26.6%; Average loss: 3.2974
Iteration: 1064; Percent complete: 26.6%; Average loss: 3.3650
Iteration: 1065; Percent complete: 26.6%; Average loss: 3.2836
Iteration: 1066; Percent complete: 26.7%; Average loss: 3.1759
Iteration: 1067; Percent complete: 26.7%; Average loss: 3.5020
Iteration: 1068; Percent complete: 26.7%; Average loss: 3.3529
Iteration: 1069; Percent complete: 26.7%; Average loss: 3.6554
Iteration: 1070; Percent complete: 26.8%; Average loss: 3.3855
Iteration: 1071; Percent complete: 26.8%; Average loss: 3.1295
Iteration: 1072; Percent complete: 26.8%; Average loss: 3.5396
Iteration: 1073; Percent complete: 26.8%; Average loss: 3.2533
Iteration: 1074; Percent complete: 26.9%; Average loss: 3.2621
Iteration: 1075; Percent complete: 26.9%; Average loss: 3.5055
Iteration: 1076; Percent complete: 26.9%; Average loss: 3.5204
Iteration: 1077; Percent complete: 26.9%; Average loss:

Iteration: 1192; Percent complete: 29.8%; Average loss: 3.3939
Iteration: 1193; Percent complete: 29.8%; Average loss: 3.4850
Iteration: 1194; Percent complete: 29.8%; Average loss: 3.2607
Iteration: 1195; Percent complete: 29.9%; Average loss: 3.4564
Iteration: 1196; Percent complete: 29.9%; Average loss: 3.3420
Iteration: 1197; Percent complete: 29.9%; Average loss: 3.3224
Iteration: 1198; Percent complete: 29.9%; Average loss: 3.4709
Iteration: 1199; Percent complete: 30.0%; Average loss: 3.2558
Iteration: 1200; Percent complete: 30.0%; Average loss: 3.4796
Iteration: 1201; Percent complete: 30.0%; Average loss: 3.4825
Iteration: 1202; Percent complete: 30.0%; Average loss: 3.3786
Iteration: 1203; Percent complete: 30.1%; Average loss: 3.5464
Iteration: 1204; Percent complete: 30.1%; Average loss: 3.3590
Iteration: 1205; Percent complete: 30.1%; Average loss: 3.4017
Iteration: 1206; Percent complete: 30.1%; Average loss: 3.6089
Iteration: 1207; Percent complete: 30.2%; Average loss:

Iteration: 1322; Percent complete: 33.1%; Average loss: 3.1487
Iteration: 1323; Percent complete: 33.1%; Average loss: 3.0766
Iteration: 1324; Percent complete: 33.1%; Average loss: 3.4188
Iteration: 1325; Percent complete: 33.1%; Average loss: 3.4053
Iteration: 1326; Percent complete: 33.1%; Average loss: 3.4210
Iteration: 1327; Percent complete: 33.2%; Average loss: 3.5000
Iteration: 1328; Percent complete: 33.2%; Average loss: 3.1996
Iteration: 1329; Percent complete: 33.2%; Average loss: 3.5967
Iteration: 1330; Percent complete: 33.2%; Average loss: 3.2522
Iteration: 1331; Percent complete: 33.3%; Average loss: 3.2773
Iteration: 1332; Percent complete: 33.3%; Average loss: 3.4858
Iteration: 1333; Percent complete: 33.3%; Average loss: 3.5986
Iteration: 1334; Percent complete: 33.4%; Average loss: 3.3811
Iteration: 1335; Percent complete: 33.4%; Average loss: 3.3509
Iteration: 1336; Percent complete: 33.4%; Average loss: 3.5846
Iteration: 1337; Percent complete: 33.4%; Average loss:

Iteration: 1452; Percent complete: 36.3%; Average loss: 3.4597
Iteration: 1453; Percent complete: 36.3%; Average loss: 3.4425
Iteration: 1454; Percent complete: 36.4%; Average loss: 3.4965
Iteration: 1455; Percent complete: 36.4%; Average loss: 3.2882
Iteration: 1456; Percent complete: 36.4%; Average loss: 3.3282
Iteration: 1457; Percent complete: 36.4%; Average loss: 3.4401
Iteration: 1458; Percent complete: 36.4%; Average loss: 3.3747
Iteration: 1459; Percent complete: 36.5%; Average loss: 3.4155
Iteration: 1460; Percent complete: 36.5%; Average loss: 3.3849
Iteration: 1461; Percent complete: 36.5%; Average loss: 3.1928
Iteration: 1462; Percent complete: 36.5%; Average loss: 3.3490
Iteration: 1463; Percent complete: 36.6%; Average loss: 3.4296
Iteration: 1464; Percent complete: 36.6%; Average loss: 3.5518
Iteration: 1465; Percent complete: 36.6%; Average loss: 3.3060
Iteration: 1466; Percent complete: 36.6%; Average loss: 3.5254
Iteration: 1467; Percent complete: 36.7%; Average loss:

Iteration: 1582; Percent complete: 39.6%; Average loss: 3.3167
Iteration: 1583; Percent complete: 39.6%; Average loss: 3.2379
Iteration: 1584; Percent complete: 39.6%; Average loss: 3.4734
Iteration: 1585; Percent complete: 39.6%; Average loss: 3.2151
Iteration: 1586; Percent complete: 39.6%; Average loss: 3.1602
Iteration: 1587; Percent complete: 39.7%; Average loss: 3.3538
Iteration: 1588; Percent complete: 39.7%; Average loss: 3.3135
Iteration: 1589; Percent complete: 39.7%; Average loss: 3.3840
Iteration: 1590; Percent complete: 39.8%; Average loss: 3.1992
Iteration: 1591; Percent complete: 39.8%; Average loss: 2.9038
Iteration: 1592; Percent complete: 39.8%; Average loss: 3.2627
Iteration: 1593; Percent complete: 39.8%; Average loss: 3.1407
Iteration: 1594; Percent complete: 39.9%; Average loss: 3.2188
Iteration: 1595; Percent complete: 39.9%; Average loss: 3.1834
Iteration: 1596; Percent complete: 39.9%; Average loss: 3.1566
Iteration: 1597; Percent complete: 39.9%; Average loss:

Iteration: 1712; Percent complete: 42.8%; Average loss: 3.2819
Iteration: 1713; Percent complete: 42.8%; Average loss: 3.2118
Iteration: 1714; Percent complete: 42.9%; Average loss: 3.4565
Iteration: 1715; Percent complete: 42.9%; Average loss: 3.3470
Iteration: 1716; Percent complete: 42.9%; Average loss: 3.4738
Iteration: 1717; Percent complete: 42.9%; Average loss: 3.2093
Iteration: 1718; Percent complete: 43.0%; Average loss: 3.3782
Iteration: 1719; Percent complete: 43.0%; Average loss: 3.1738
Iteration: 1720; Percent complete: 43.0%; Average loss: 3.1553
Iteration: 1721; Percent complete: 43.0%; Average loss: 3.1090
Iteration: 1722; Percent complete: 43.0%; Average loss: 3.1817
Iteration: 1723; Percent complete: 43.1%; Average loss: 3.4047
Iteration: 1724; Percent complete: 43.1%; Average loss: 3.3916
Iteration: 1725; Percent complete: 43.1%; Average loss: 3.2446
Iteration: 1726; Percent complete: 43.1%; Average loss: 3.2911
Iteration: 1727; Percent complete: 43.2%; Average loss:

Iteration: 1842; Percent complete: 46.1%; Average loss: 3.3469
Iteration: 1843; Percent complete: 46.1%; Average loss: 3.4142
Iteration: 1844; Percent complete: 46.1%; Average loss: 3.3900
Iteration: 1845; Percent complete: 46.1%; Average loss: 3.1471
Iteration: 1846; Percent complete: 46.2%; Average loss: 3.3329
Iteration: 1847; Percent complete: 46.2%; Average loss: 3.0521
Iteration: 1848; Percent complete: 46.2%; Average loss: 3.3973
Iteration: 1849; Percent complete: 46.2%; Average loss: 3.1800
Iteration: 1850; Percent complete: 46.2%; Average loss: 2.8550
Iteration: 1851; Percent complete: 46.3%; Average loss: 3.0793
Iteration: 1852; Percent complete: 46.3%; Average loss: 2.9853
Iteration: 1853; Percent complete: 46.3%; Average loss: 3.2290
Iteration: 1854; Percent complete: 46.4%; Average loss: 3.3013
Iteration: 1855; Percent complete: 46.4%; Average loss: 3.2227
Iteration: 1856; Percent complete: 46.4%; Average loss: 3.1741
Iteration: 1857; Percent complete: 46.4%; Average loss:

Iteration: 1972; Percent complete: 49.3%; Average loss: 3.3394
Iteration: 1973; Percent complete: 49.3%; Average loss: 3.0612
Iteration: 1974; Percent complete: 49.4%; Average loss: 3.0705
Iteration: 1975; Percent complete: 49.4%; Average loss: 2.9886
Iteration: 1976; Percent complete: 49.4%; Average loss: 3.3119
Iteration: 1977; Percent complete: 49.4%; Average loss: 2.9873
Iteration: 1978; Percent complete: 49.5%; Average loss: 2.9104
Iteration: 1979; Percent complete: 49.5%; Average loss: 2.8608
Iteration: 1980; Percent complete: 49.5%; Average loss: 3.0454
Iteration: 1981; Percent complete: 49.5%; Average loss: 3.1780
Iteration: 1982; Percent complete: 49.5%; Average loss: 3.0574
Iteration: 1983; Percent complete: 49.6%; Average loss: 3.2709
Iteration: 1984; Percent complete: 49.6%; Average loss: 3.3109
Iteration: 1985; Percent complete: 49.6%; Average loss: 3.2721
Iteration: 1986; Percent complete: 49.6%; Average loss: 3.0064
Iteration: 1987; Percent complete: 49.7%; Average loss:

Iteration: 2102; Percent complete: 52.5%; Average loss: 3.1959
Iteration: 2103; Percent complete: 52.6%; Average loss: 3.1323
Iteration: 2104; Percent complete: 52.6%; Average loss: 2.9287
Iteration: 2105; Percent complete: 52.6%; Average loss: 3.2700
Iteration: 2106; Percent complete: 52.6%; Average loss: 3.4533
Iteration: 2107; Percent complete: 52.7%; Average loss: 3.2739
Iteration: 2108; Percent complete: 52.7%; Average loss: 2.9336
Iteration: 2109; Percent complete: 52.7%; Average loss: 2.9961
Iteration: 2110; Percent complete: 52.8%; Average loss: 3.2872
Iteration: 2111; Percent complete: 52.8%; Average loss: 3.0462
Iteration: 2112; Percent complete: 52.8%; Average loss: 3.0274
Iteration: 2113; Percent complete: 52.8%; Average loss: 3.3355
Iteration: 2114; Percent complete: 52.8%; Average loss: 3.0454
Iteration: 2115; Percent complete: 52.9%; Average loss: 3.1181
Iteration: 2116; Percent complete: 52.9%; Average loss: 3.0993
Iteration: 2117; Percent complete: 52.9%; Average loss:

Iteration: 2232; Percent complete: 55.8%; Average loss: 3.0647
Iteration: 2233; Percent complete: 55.8%; Average loss: 3.3635
Iteration: 2234; Percent complete: 55.9%; Average loss: 2.9723
Iteration: 2235; Percent complete: 55.9%; Average loss: 2.9529
Iteration: 2236; Percent complete: 55.9%; Average loss: 3.2244
Iteration: 2237; Percent complete: 55.9%; Average loss: 2.9170
Iteration: 2238; Percent complete: 56.0%; Average loss: 2.9250
Iteration: 2239; Percent complete: 56.0%; Average loss: 3.2486
Iteration: 2240; Percent complete: 56.0%; Average loss: 3.0659
Iteration: 2241; Percent complete: 56.0%; Average loss: 2.8901
Iteration: 2242; Percent complete: 56.0%; Average loss: 2.9334
Iteration: 2243; Percent complete: 56.1%; Average loss: 2.9673
Iteration: 2244; Percent complete: 56.1%; Average loss: 3.0676
Iteration: 2245; Percent complete: 56.1%; Average loss: 3.1115
Iteration: 2246; Percent complete: 56.1%; Average loss: 3.1872
Iteration: 2247; Percent complete: 56.2%; Average loss:

Iteration: 2362; Percent complete: 59.1%; Average loss: 2.9774
Iteration: 2363; Percent complete: 59.1%; Average loss: 3.0623
Iteration: 2364; Percent complete: 59.1%; Average loss: 2.9872
Iteration: 2365; Percent complete: 59.1%; Average loss: 3.0753
Iteration: 2366; Percent complete: 59.2%; Average loss: 2.9595
Iteration: 2367; Percent complete: 59.2%; Average loss: 3.0294
Iteration: 2368; Percent complete: 59.2%; Average loss: 2.9742
Iteration: 2369; Percent complete: 59.2%; Average loss: 3.1194
Iteration: 2370; Percent complete: 59.2%; Average loss: 3.1185
Iteration: 2371; Percent complete: 59.3%; Average loss: 3.1674
Iteration: 2372; Percent complete: 59.3%; Average loss: 3.3124
Iteration: 2373; Percent complete: 59.3%; Average loss: 3.2195
Iteration: 2374; Percent complete: 59.4%; Average loss: 3.0266
Iteration: 2375; Percent complete: 59.4%; Average loss: 3.1692
Iteration: 2376; Percent complete: 59.4%; Average loss: 3.0283
Iteration: 2377; Percent complete: 59.4%; Average loss:

Iteration: 2492; Percent complete: 62.3%; Average loss: 3.2254
Iteration: 2493; Percent complete: 62.3%; Average loss: 3.1082
Iteration: 2494; Percent complete: 62.4%; Average loss: 2.9122
Iteration: 2495; Percent complete: 62.4%; Average loss: 2.9752
Iteration: 2496; Percent complete: 62.4%; Average loss: 3.3322
Iteration: 2497; Percent complete: 62.4%; Average loss: 3.1241
Iteration: 2498; Percent complete: 62.5%; Average loss: 2.8540
Iteration: 2499; Percent complete: 62.5%; Average loss: 3.0682
Iteration: 2500; Percent complete: 62.5%; Average loss: 2.9519
Iteration: 2501; Percent complete: 62.5%; Average loss: 3.0551
Iteration: 2502; Percent complete: 62.5%; Average loss: 3.1839
Iteration: 2503; Percent complete: 62.6%; Average loss: 2.8101
Iteration: 2504; Percent complete: 62.6%; Average loss: 3.0774
Iteration: 2505; Percent complete: 62.6%; Average loss: 3.0285
Iteration: 2506; Percent complete: 62.6%; Average loss: 3.2162
Iteration: 2507; Percent complete: 62.7%; Average loss:

Iteration: 2622; Percent complete: 65.5%; Average loss: 3.1850
Iteration: 2623; Percent complete: 65.6%; Average loss: 3.1722
Iteration: 2624; Percent complete: 65.6%; Average loss: 3.4252
Iteration: 2625; Percent complete: 65.6%; Average loss: 3.3544
Iteration: 2626; Percent complete: 65.6%; Average loss: 2.9467
Iteration: 2627; Percent complete: 65.7%; Average loss: 3.1121
Iteration: 2628; Percent complete: 65.7%; Average loss: 2.9173
Iteration: 2629; Percent complete: 65.7%; Average loss: 3.0894
Iteration: 2630; Percent complete: 65.8%; Average loss: 3.0972
Iteration: 2631; Percent complete: 65.8%; Average loss: 2.7473
Iteration: 2632; Percent complete: 65.8%; Average loss: 2.9855
Iteration: 2633; Percent complete: 65.8%; Average loss: 3.0329
Iteration: 2634; Percent complete: 65.8%; Average loss: 2.8869
Iteration: 2635; Percent complete: 65.9%; Average loss: 2.8216
Iteration: 2636; Percent complete: 65.9%; Average loss: 3.0789
Iteration: 2637; Percent complete: 65.9%; Average loss:

Iteration: 2752; Percent complete: 68.8%; Average loss: 2.9660
Iteration: 2753; Percent complete: 68.8%; Average loss: 3.0356
Iteration: 2754; Percent complete: 68.8%; Average loss: 2.7797
Iteration: 2755; Percent complete: 68.9%; Average loss: 3.1472
Iteration: 2756; Percent complete: 68.9%; Average loss: 3.0678
Iteration: 2757; Percent complete: 68.9%; Average loss: 2.8506
Iteration: 2758; Percent complete: 69.0%; Average loss: 2.8600
Iteration: 2759; Percent complete: 69.0%; Average loss: 2.7884
Iteration: 2760; Percent complete: 69.0%; Average loss: 2.8202
Iteration: 2761; Percent complete: 69.0%; Average loss: 3.0616
Iteration: 2762; Percent complete: 69.0%; Average loss: 2.8477
Iteration: 2763; Percent complete: 69.1%; Average loss: 2.8211
Iteration: 2764; Percent complete: 69.1%; Average loss: 2.8704
Iteration: 2765; Percent complete: 69.1%; Average loss: 2.7501
Iteration: 2766; Percent complete: 69.2%; Average loss: 3.0534
Iteration: 2767; Percent complete: 69.2%; Average loss:

Iteration: 2882; Percent complete: 72.0%; Average loss: 2.9190
Iteration: 2883; Percent complete: 72.1%; Average loss: 2.9986
Iteration: 2884; Percent complete: 72.1%; Average loss: 2.7099
Iteration: 2885; Percent complete: 72.1%; Average loss: 3.1343
Iteration: 2886; Percent complete: 72.2%; Average loss: 2.8536
Iteration: 2887; Percent complete: 72.2%; Average loss: 2.9391
Iteration: 2888; Percent complete: 72.2%; Average loss: 3.1116
Iteration: 2889; Percent complete: 72.2%; Average loss: 2.9399
Iteration: 2890; Percent complete: 72.2%; Average loss: 2.8000
Iteration: 2891; Percent complete: 72.3%; Average loss: 3.2229
Iteration: 2892; Percent complete: 72.3%; Average loss: 2.6412
Iteration: 2893; Percent complete: 72.3%; Average loss: 2.9106
Iteration: 2894; Percent complete: 72.4%; Average loss: 3.0768
Iteration: 2895; Percent complete: 72.4%; Average loss: 2.8732
Iteration: 2896; Percent complete: 72.4%; Average loss: 2.7778
Iteration: 2897; Percent complete: 72.4%; Average loss:

Iteration: 3012; Percent complete: 75.3%; Average loss: 2.9772
Iteration: 3013; Percent complete: 75.3%; Average loss: 2.9775
Iteration: 3014; Percent complete: 75.3%; Average loss: 2.7774
Iteration: 3015; Percent complete: 75.4%; Average loss: 2.7635
Iteration: 3016; Percent complete: 75.4%; Average loss: 2.8773
Iteration: 3017; Percent complete: 75.4%; Average loss: 2.8198
Iteration: 3018; Percent complete: 75.4%; Average loss: 2.8512
Iteration: 3019; Percent complete: 75.5%; Average loss: 2.8638
Iteration: 3020; Percent complete: 75.5%; Average loss: 2.7799
Iteration: 3021; Percent complete: 75.5%; Average loss: 2.6483
Iteration: 3022; Percent complete: 75.5%; Average loss: 2.7663
Iteration: 3023; Percent complete: 75.6%; Average loss: 3.0319
Iteration: 3024; Percent complete: 75.6%; Average loss: 2.7844
Iteration: 3025; Percent complete: 75.6%; Average loss: 2.7896
Iteration: 3026; Percent complete: 75.6%; Average loss: 2.8647
Iteration: 3027; Percent complete: 75.7%; Average loss:

Iteration: 3142; Percent complete: 78.5%; Average loss: 2.9260
Iteration: 3143; Percent complete: 78.6%; Average loss: 2.9763
Iteration: 3144; Percent complete: 78.6%; Average loss: 2.6049
Iteration: 3145; Percent complete: 78.6%; Average loss: 2.9413
Iteration: 3146; Percent complete: 78.6%; Average loss: 2.6873
Iteration: 3147; Percent complete: 78.7%; Average loss: 2.8068
Iteration: 3148; Percent complete: 78.7%; Average loss: 2.7990
Iteration: 3149; Percent complete: 78.7%; Average loss: 2.8992
Iteration: 3150; Percent complete: 78.8%; Average loss: 2.8666
Iteration: 3151; Percent complete: 78.8%; Average loss: 3.2813
Iteration: 3152; Percent complete: 78.8%; Average loss: 2.9246
Iteration: 3153; Percent complete: 78.8%; Average loss: 2.7235
Iteration: 3154; Percent complete: 78.8%; Average loss: 2.8587
Iteration: 3155; Percent complete: 78.9%; Average loss: 3.0159
Iteration: 3156; Percent complete: 78.9%; Average loss: 2.7566
Iteration: 3157; Percent complete: 78.9%; Average loss:

Iteration: 3272; Percent complete: 81.8%; Average loss: 2.7618
Iteration: 3273; Percent complete: 81.8%; Average loss: 2.8642
Iteration: 3274; Percent complete: 81.8%; Average loss: 2.7862
Iteration: 3275; Percent complete: 81.9%; Average loss: 2.8478
Iteration: 3276; Percent complete: 81.9%; Average loss: 3.0523
Iteration: 3277; Percent complete: 81.9%; Average loss: 2.4856
Iteration: 3278; Percent complete: 82.0%; Average loss: 2.9132
Iteration: 3279; Percent complete: 82.0%; Average loss: 2.7394
Iteration: 3280; Percent complete: 82.0%; Average loss: 2.6238
Iteration: 3281; Percent complete: 82.0%; Average loss: 2.7640
Iteration: 3282; Percent complete: 82.0%; Average loss: 2.8723
Iteration: 3283; Percent complete: 82.1%; Average loss: 2.8119
Iteration: 3284; Percent complete: 82.1%; Average loss: 2.7201
Iteration: 3285; Percent complete: 82.1%; Average loss: 2.9276
Iteration: 3286; Percent complete: 82.2%; Average loss: 2.8068
Iteration: 3287; Percent complete: 82.2%; Average loss:

Iteration: 3402; Percent complete: 85.0%; Average loss: 2.7259
Iteration: 3403; Percent complete: 85.1%; Average loss: 2.8753
Iteration: 3404; Percent complete: 85.1%; Average loss: 2.6321
Iteration: 3405; Percent complete: 85.1%; Average loss: 2.6674
Iteration: 3406; Percent complete: 85.2%; Average loss: 2.8712
Iteration: 3407; Percent complete: 85.2%; Average loss: 2.6984
Iteration: 3408; Percent complete: 85.2%; Average loss: 2.9985
Iteration: 3409; Percent complete: 85.2%; Average loss: 2.8046
Iteration: 3410; Percent complete: 85.2%; Average loss: 2.5563
Iteration: 3411; Percent complete: 85.3%; Average loss: 2.6927
Iteration: 3412; Percent complete: 85.3%; Average loss: 2.6678
Iteration: 3413; Percent complete: 85.3%; Average loss: 2.9387
Iteration: 3414; Percent complete: 85.4%; Average loss: 2.6105
Iteration: 3415; Percent complete: 85.4%; Average loss: 2.6689
Iteration: 3416; Percent complete: 85.4%; Average loss: 2.9117
Iteration: 3417; Percent complete: 85.4%; Average loss:

Iteration: 3532; Percent complete: 88.3%; Average loss: 2.6656
Iteration: 3533; Percent complete: 88.3%; Average loss: 2.7582
Iteration: 3534; Percent complete: 88.3%; Average loss: 2.8810
Iteration: 3535; Percent complete: 88.4%; Average loss: 2.7022
Iteration: 3536; Percent complete: 88.4%; Average loss: 2.8435
Iteration: 3537; Percent complete: 88.4%; Average loss: 2.5553
Iteration: 3538; Percent complete: 88.4%; Average loss: 2.8079
Iteration: 3539; Percent complete: 88.5%; Average loss: 2.9520
Iteration: 3540; Percent complete: 88.5%; Average loss: 2.8451
Iteration: 3541; Percent complete: 88.5%; Average loss: 2.8444
Iteration: 3542; Percent complete: 88.5%; Average loss: 2.5785
Iteration: 3543; Percent complete: 88.6%; Average loss: 2.9163
Iteration: 3544; Percent complete: 88.6%; Average loss: 2.7142
Iteration: 3545; Percent complete: 88.6%; Average loss: 2.9006
Iteration: 3546; Percent complete: 88.6%; Average loss: 2.6988
Iteration: 3547; Percent complete: 88.7%; Average loss:

Iteration: 3662; Percent complete: 91.5%; Average loss: 2.6141
Iteration: 3663; Percent complete: 91.6%; Average loss: 2.9334
Iteration: 3664; Percent complete: 91.6%; Average loss: 2.7838
Iteration: 3665; Percent complete: 91.6%; Average loss: 2.8499
Iteration: 3666; Percent complete: 91.6%; Average loss: 2.8398
Iteration: 3667; Percent complete: 91.7%; Average loss: 2.6392
Iteration: 3668; Percent complete: 91.7%; Average loss: 2.8593
Iteration: 3669; Percent complete: 91.7%; Average loss: 2.5512
Iteration: 3670; Percent complete: 91.8%; Average loss: 2.6822
Iteration: 3671; Percent complete: 91.8%; Average loss: 2.6742
Iteration: 3672; Percent complete: 91.8%; Average loss: 2.6610
Iteration: 3673; Percent complete: 91.8%; Average loss: 2.6997
Iteration: 3674; Percent complete: 91.8%; Average loss: 2.6330
Iteration: 3675; Percent complete: 91.9%; Average loss: 2.6030
Iteration: 3676; Percent complete: 91.9%; Average loss: 2.7809
Iteration: 3677; Percent complete: 91.9%; Average loss:

Iteration: 3792; Percent complete: 94.8%; Average loss: 2.7839
Iteration: 3793; Percent complete: 94.8%; Average loss: 2.7443
Iteration: 3794; Percent complete: 94.8%; Average loss: 2.8033
Iteration: 3795; Percent complete: 94.9%; Average loss: 2.6578
Iteration: 3796; Percent complete: 94.9%; Average loss: 2.8698
Iteration: 3797; Percent complete: 94.9%; Average loss: 2.5675
Iteration: 3798; Percent complete: 95.0%; Average loss: 2.5393
Iteration: 3799; Percent complete: 95.0%; Average loss: 2.8083
Iteration: 3800; Percent complete: 95.0%; Average loss: 2.6373
Iteration: 3801; Percent complete: 95.0%; Average loss: 2.9051
Iteration: 3802; Percent complete: 95.0%; Average loss: 2.8406
Iteration: 3803; Percent complete: 95.1%; Average loss: 2.7032
Iteration: 3804; Percent complete: 95.1%; Average loss: 2.6596
Iteration: 3805; Percent complete: 95.1%; Average loss: 2.5755
Iteration: 3806; Percent complete: 95.2%; Average loss: 2.6886
Iteration: 3807; Percent complete: 95.2%; Average loss:

Iteration: 3922; Percent complete: 98.0%; Average loss: 2.8825
Iteration: 3923; Percent complete: 98.1%; Average loss: 2.7625
Iteration: 3924; Percent complete: 98.1%; Average loss: 2.5693
Iteration: 3925; Percent complete: 98.1%; Average loss: 2.7525
Iteration: 3926; Percent complete: 98.2%; Average loss: 2.6971
Iteration: 3927; Percent complete: 98.2%; Average loss: 2.6311
Iteration: 3928; Percent complete: 98.2%; Average loss: 2.7837
Iteration: 3929; Percent complete: 98.2%; Average loss: 2.5881
Iteration: 3930; Percent complete: 98.2%; Average loss: 2.4270
Iteration: 3931; Percent complete: 98.3%; Average loss: 2.5573
Iteration: 3932; Percent complete: 98.3%; Average loss: 2.5283
Iteration: 3933; Percent complete: 98.3%; Average loss: 2.4112
Iteration: 3934; Percent complete: 98.4%; Average loss: 2.6141
Iteration: 3935; Percent complete: 98.4%; Average loss: 2.4613
Iteration: 3936; Percent complete: 98.4%; Average loss: 2.5821
Iteration: 3937; Percent complete: 98.4%; Average loss:

In [20]:
# Set dropout layers to eval mode
encoder.eval()
decoder.eval()

# Initialize search module
searcher = GreedySearchDecoder(encoder, decoder)

# Begin chatting (uncomment and run the following line to begin)
evaluateInput(encoder, decoder, searcher, voc)

> hello
Bot: hello . i m sorry .
> what is your name
Bot: i don t know . i
> why
Bot: because i m not asking you to .
> you are so funny
Bot: i m fine . i m sorry .
> how about the weather
Bot: i know . i know .
> you know?
Bot: yeah . i know . ?
> Give me a song
Bot: you re not gonna make it all right ?
> How about the weather
Bot: i know . i know .
> You knwo:
Error: Encountered unknown word.
> dsas
Error: Encountered unknown word.
> ok
Bot: i ll be right back .
> what is that
Bot: what is it ? i m sorry .
> are you happyÉ
Error: Encountered unknown word.
> are you happy?
Bot: i m fine . i m happy .
> ds


KeyboardInterrupt: 

In [None]:
print(1)