# ChatBot Tutorial

* Build a chatbot using a sequence-to-sequence (seq2seq) model
* Train a simple chatbot on movie scripts from the Cornell Movie-Dialogs Corpus

#### Steps

   * Handle loading and preprocessing of Cornell Movie-Dialogs Corpus dataset
   * Implement a sequence-to-sequence model with Luong attention mechanism(s)
   * Jointly train encoder and decoder models using mini-batches
   * Implement greedy-search decoding module
   * Interact with trained chatbot

In [48]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

In [49]:
import torch
from torch.jit import script, trace
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import csv
import random
import re
import os
import unicodedata
import codecs
from io import open
import itertools
import math


In [50]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

#### Load and Process Data

Dataset:
   * 220,579 conversational exchanges between 10,292 pairs of movie characters
   * 9,035 characters from 617 movies
   * 304,713 total utterances

In [51]:
corpus_name = "cornell movie-dialogs corpus"
corpus = os.path.join('data', corpus_name)

def printLines(file, n = 10):
    with open(file, 'rb') as datafile:
        lines = datafile.readlines()
    for line in lines[:n]:
        print(line)
        
printLines(os.path.join(corpus, 'movie_lines.txt'))        

b'L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!\n'
b'L1044 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ They do to!\n'
b'L985 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I hope so.\n'
b'L984 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ She okay?\n'
b"L925 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Let's go.\n"
b'L924 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ Wow\n'
b"L872 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Okay -- you're gonna need to learn how to lie.\n"
b'L871 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ No\n'
b'L870 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I\'m kidding.  You know how sometimes you just become this "persona"?  And you don\'t know how to quit?\n'
b'L869 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Like my fear of wearing pastels?\n'


In [52]:
# Splits each line of the file into a dictionary of fields, keyed by lineID
def loadLines(fileName, fields):
    lines = {}
    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(" +++$+++ ")
            # Extract the fields of this line
            lineObj = {}
            for i, field in enumerate(fields):
                lineObj[field] = values[i]
            # Store the record once, after all of its fields are filled in
            lines[lineObj['lineID']] = lineObj
    return lines
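Each record in `movie_lines.txt` is one line whose fields are separated by ` +++$+++ `, in the order given by `MOVIE_LINES_FIELDS`. A minimal standalone sketch of the same split, applied to the first sample line printed above (`parse_line` is a hypothetical helper, not part of the notebook). The text field keeps its trailing newline, which `extractSentancePairs` later removes with `.strip()`:

```python
FIELDS = ["lineID", "characterID", "movieID", "character", "text"]
SEP = " +++$+++ "

def parse_line(raw, fields=FIELDS):
    """Split one raw corpus record into a field-name -> value dict."""
    values = raw.split(SEP)
    return dict(zip(fields, values))

sample = "L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!\n"
record = parse_line(sample)
print(record["lineID"], record["character"], repr(record["text"]))
# L1045 BIANCA 'They do not!\n'
```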

In [53]:
import ast

# Group lines into conversations
def loadConversations(fileName, lines, fields):
    conversations = []
    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(" +++$+++ ")
            convObj = {}
            for i, field in enumerate(fields):
                convObj[field] = values[i]
            # utteranceIDs is a list literal such as "['L194', 'L195']";
            # ast.literal_eval parses it without executing arbitrary code (unlike eval)
            lineIds = ast.literal_eval(convObj["utteranceIDs"])
            convObj["lines"] = []
            for lineId in lineIds:
                convObj["lines"].append(lines[lineId])
            conversations.append(convObj)
    return conversations

In [54]:
# Extract query/response sentence pairs from each conversation (interior lines appear in two pairs)
def extractSentancePairs(conversations):
    qa_pairs = []
    #iterate through conversations
    for conversation in conversations:
        #iterate through lines of conversation
        for i in range(len(conversation["lines"]) - 1):
            inputLine = conversation["lines"][i]["text"].strip()
            targetLine = conversation["lines"][i + 1]["text"].strip()
            if inputLine and targetLine:
                qa_pairs.append([inputLine, targetLine])
    return qa_pairs
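Consecutive lines of a conversation become (query, response) pairs, so every interior line appears twice: once as a response and once as the next query. That is the duplication the comment above refers to. A minimal sketch of the same sliding-window rule on a short, hypothetical three-line conversation (`sliding_pairs` is an illustrative helper, not part of the notebook):

```python
def sliding_pairs(utterances):
    """Pair each utterance with the one that follows it, skipping empty lines."""
    result = []
    for query, response in zip(utterances, utterances[1:]):
        query, response = query.strip(), response.strip()
        if query and response:
            result.append([query, response])
    return result

conversation = [
    "Can we make this quick?",
    "Well, I thought we'd start with pronunciation.",
    "Not the hacking and gagging and spitting part.",
]
for pair in sliding_pairs(conversation):
    print(pair)
```

Three lines yield two pairs, and the middle line is used in both.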
        

In [55]:
# Path to the new, formatted file (one tab-separated query/response pair per line)
datafile = os.path.join(corpus, "formatted_movie_lines.txt")
delimiter = '\t'
delimiter = str(codecs.decode(delimiter, 'unicode_escape'))

lines = {}
conversations = []

MOVIE_LINES_FIELDS = ["lineID", "characterID", "movieID", "character", "text"]
MOVIE_CONVERSATIONS_FIELDS = ["character1ID", "character2ID", "movieID", "utteranceIDs"]
print("\n Processing corpus...")
lines = loadLines(os.path.join(corpus, "movie_lines.txt"), MOVIE_LINES_FIELDS)
print("\n Loading Conversations...")
conversations = loadConversations(os.path.join(corpus, "movie_conversations.txt"), lines, MOVIE_CONVERSATIONS_FIELDS)

print("\n Writing newly formatted file...")
with open(datafile, 'w', encoding = 'utf-8') as outputfile:
    writer = csv.writer(outputfile, delimiter = delimiter, lineterminator = "\n")
    for pair in extractSentancePairs(conversations):
        writer.writerow(pair)

print("\n Printing sample lines... ")
printLines(datafile)


 Processing corpus...

 Loading Conversations...

 Writing newly formatted file...

 Printing sample lines... 
b"Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\tWell, I thought we'd start with pronunciation, if that's okay with you.\n"
b"Well, I thought we'd start with pronunciation, if that's okay with you.\tNot the hacking and gagging and spitting part.  Please.\n"
b"Not the hacking and gagging and spitting part.  Please.\tOkay... then how 'bout we try out some French cuisine.  Saturday?  Night?\n"
b"You're asking me out.  That's so cute. What's your name again?\tForget it.\n"
b"No, no, it's my fault -- we didn't have a proper introduction ---\tCameron.\n"
b"Cameron.\tThe thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\n"
b"The thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can

#### Load and Trim Data

The `Voc` class maps words to indexes and back, and counts how often each word occurs so that rare words can be trimmed.

In [57]:
# Default word tokens
PAD_token = 0  # Used for padding short sentences
SOS_token = 1  # Start-of-sentence token
EOS_token = 2  # End-of-sentence token

class Voc:
    def __init__(self, name):
        self.name = name
        self.trimmed = False
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3  # Count SOS, EOS, PAD

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1
        else:
            self.word2count[word] += 1

    # Remove words below a certain count threshold
    def trim(self, min_count):
        if self.trimmed:
            return
        self.trimmed = True

        keep_words = []

        for k, v in self.word2count.items():
            if v >= min_count:
                keep_words.append(k)

        print('keep_words {} / {} = {:.4f}'.format(
            len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index)
        ))

        # Reinitialize dictionaries
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3 # Count default tokens

        for word in keep_words:
            self.addWord(word)
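`trim` rebuilds the vocabulary rather than deleting entries: the surviving words are re-added through `addWord`, so they receive new contiguous indexes starting at 3 (after PAD, SOS and EOS) and their counts restart at 1. The sketch below reproduces the counting and re-indexing with `collections.Counter`; `build_and_trim` is a hypothetical helper, assuming already-normalized, space-separated sentences.

```python
from collections import Counter

RESERVED = {0: "PAD", 1: "SOS", 2: "EOS"}

def build_and_trim(sentences, min_count):
    """Count words, keep those seen >= min_count times, and re-index them from 3."""
    counts = Counter(word for s in sentences for word in s.split(' '))
    # Counter keeps first-seen insertion order, like Voc.word2count
    kept = [w for w, c in counts.items() if c >= min_count]
    word2index = {w: i for i, w in enumerate(kept, start=len(RESERVED))}
    index2word = dict(RESERVED)
    index2word.update({i: w for w, i in word2index.items()})
    return word2index, index2word

w2i, i2w = build_and_trim(["hi there .", "hi you .", "bye ."], min_count=2)
print(w2i)  # {'hi': 3, '.': 4}
```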

In [58]:
MAX_LENGTH = 10

# Turn a Unicode string to plain ASCII, thanks to
# http://stackoverflow.com/a/518232/2809427

def unicodeToASCII(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

# Lowercase, trim, and remove non-letter characters
def normalizeString(s):
    s = unicodeToASCII(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    s = re.sub(r"\s+", r" ", s).strip()
    return s

# Read query/response pairs and return a voc object
def readVocs(datafile, corpus_name):
    print("\n Reading line... ")
    # Read the file, split it into lines, and normalize each tab-separated sentence
    with open(datafile, encoding = "utf-8") as f:
        lines = f.read().strip().split('\n')
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]
    voc = Voc(corpus_name)
    return voc, pairs

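# Keep a pair only if both sentences have fewer than MAX_LENGTH words
# (the EOS token appended later brings each sequence to at most MAX_LENGTH tokens)
def filterPair(p):
    return len(p[0].split(' ')) < MAX_LENGTH and len(p[1].split(' ')) < MAX_LENGTH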

def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]

def loadPrepareData(corpus, corpus_name, datafile, save_dir):
    print("Start preparing training data ...")
    voc, pairs = readVocs(datafile, corpus_name)
    print("Read {!s} sentence pairs".format(len(pairs)))
    pairs = filterPairs(pairs)
    print("Trimmed to {!s} sentence pairs".format(len(pairs)))
    print("Counting words...")
    for pair in pairs:
        voc.addSentence(pair[0])
        voc.addSentence(pair[1])
    print("Counted words:", voc.num_words)
    return voc, pairs


save_dir = os.path.join("data", "save")
voc, pairs = loadPrepareData(corpus, corpus_name, datafile, save_dir)

# Print some pairs to validate
print("\npairs:")
for pair in pairs[:10]:
    print(pair)


Start preparing training data ...

 Reading line... 
Read 221282 sentence pairs
Trimmed to 64271 sentence pairs
Counting words...
Counted words: 18008

pairs:
['there .', 'where ?']
['you have my word . as a gentleman', 'you re sweet .']
['hi .', 'looks like things worked out tonight huh ?']
['you know chastity ?', 'i believe we share an art instructor']
['have fun tonight ?', 'tons']
['well no . . .', 'then that s all you had to say .']
['then that s all you had to say .', 'but']
['but', 'you always been this selfish ?']
['do you listen to this crap ?', 'what crap ?']
['what good stuff ?', 'the real you .']
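Normalization lowercases, strips accents, detaches `.`, `!` and `?` into their own tokens, and turns every other run of non-letters (apostrophes, digits, dashes, spaces) into a single space. That is why contractions appear as `you re` and `that s` in the pairs printed above. A standalone copy of the two functions, renamed so it does not shadow the notebook's versions, to check a few inputs:

```python
import re
import unicodedata

def unicode_to_ascii(s):
    # Decompose accented characters (NFD) and drop the combining marks (category 'Mn')
    return ''.join(c for c in unicodedata.normalize('NFD', s)
                   if unicodedata.category(c) != 'Mn')

def normalize_string(s):
    s = unicode_to_ascii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)      # put a space before sentence punctuation
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)  # any other non-letter run becomes a space
    return re.sub(r"\s+", r" ", s).strip()

print(normalize_string("Aren't you hungry?!"))    # aren t you hungry ? !
print(normalize_string("  Café -- 42 times.  "))  # cafe times .
```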


In [59]:
# Trim rarely used words to shrink the vocabulary, and with it the output space the decoder searches over

In [99]:
MIN_COUNT = 3    # Minimum word count threshold for trimming

def trimRareWords(voc, pairs, MIN_COUNT):
    # Trim words used under the MIN_COUNT from the voc
    voc.trim(MIN_COUNT)
    # Filter out pairs with trimmed words
    keep_pairs = []
    for pair in pairs:
        input_sentence = pair[0]
        output_sentence = pair[1]
        keep_input = True
        keep_output = True
        # Check input sentence
        for word in input_sentence.split(' '):
            if word not in voc.word2index:
                keep_input = False
                break
        # Check output sentence
        for word in output_sentence.split(' '):
            if word not in voc.word2index:
                keep_output = False
                break

        # Only keep pairs that do not contain trimmed word(s) in their input or output sentence
        if keep_input and keep_output:
            keep_pairs.append(pair)

    print("Trimmed from {} pairs to {}, {:.4f} of total".format(len(pairs), len(keep_pairs), len(keep_pairs) / len(pairs)))
    return keep_pairs


# Trim voc and pairs
pairs = trimRareWords(voc, pairs, MIN_COUNT)

Trimmed from 58043 pairs to 53165, 0.9160 of total
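A pair survives trimming only if every word on both sides is still in the vocabulary; a single rare word anywhere discards the whole pair, which is why the pair count drops even though only rare words were removed. The same rule written with `all()`, using a hypothetical toy vocabulary (`keep_pair` is an illustrative helper):

```python
def keep_pair(pair, vocab):
    """True if every word of both the query and the response is in vocab."""
    return all(word in vocab for sentence in pair for word in sentence.split(' '))

toy_vocab = {"hi", "there", "where", "?", "."}
toy_pairs = [["hi there .", "where ?"], ["hi there .", "where s chastity ?"]]
kept = [p for p in toy_pairs if keep_pair(p, toy_vocab)]
print(kept)  # [['hi there .', 'where ?']]
```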


In [100]:
def indexesFromSentence(voc, sentence):
    return [voc.word2index[word] for word in sentence.split(' ')] + [EOS_token]


def zeroPadding(l, fillvalue=PAD_token):
    return list(itertools.zip_longest(*l, fillvalue=fillvalue))

def binaryMatrix(l, value=PAD_token):
    m = []
    for i, seq in enumerate(l):
        m.append([])
        for token in seq:
            if token == PAD_token:
                m[i].append(0)
            else:
                m[i].append(1)
    return m

# Returns padded input sequence tensor and lengths
def inputVar(l, voc):
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    padVar = torch.LongTensor(padList)
    return padVar, lengths

# Returns padded target sequence tensor, padding mask, and max target length
def outputVar(l, voc):
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    max_target_len = max([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    mask = binaryMatrix(padList)
    mask = torch.ByteTensor(mask)
    padVar = torch.LongTensor(padList)
    return padVar, mask, max_target_len

# Returns all items for a given batch of pairs
def batch2TrainData(voc, pair_batch):
    pair_batch.sort(key=lambda x: len(x[0].split(" ")), reverse=True)
    input_batch, output_batch = [], []
    for pair in pair_batch:
        input_batch.append(pair[0])
        output_batch.append(pair[1])
    inp, lengths = inputVar(input_batch, voc)
    output, mask, max_target_len = outputVar(output_batch, voc)
    return inp, lengths, output, mask, max_target_len


# Example for validation
small_batch_size = 5
batches = batch2TrainData(voc, [random.choice(pairs) for _ in range(small_batch_size)])
input_variable, lengths, target_variable, mask, max_target_len = batches

print("input_variable:", input_variable)
print("lengths:", lengths)
print("target_variable:", target_variable)
print("mask:", mask)
print("max_target_len:", max_target_len)

input_variable: tensor([[ 122,  101,   25,  750,   25],
        [   5,  258,  200,  147,   94],
        [  92,  147,  483,   92,  117],
        [   7,    7, 2748,    7,    4],
        [ 123,   92,    4,    6,    2],
        [   6,    4,    2,    2,    0],
        [   2,    2,    0,    0,    0]])
lengths: tensor([7, 7, 6, 6, 5])
target_variable: tensor([[1034,  122,    8, 1082,   25],
        [  36,    7,    7,   56,  809],
        [  14,  215,  534,   12,    7],
        [   4,    6,  208, 1295,   94],
        [   4,    2, 5543,    4,    4],
        [   4,    0,  177,  158,    2],
        [   2,    0,    7,  935,    0],
        [   0,    0,    6,  159,    0],
        [   0,    0,    2,    4,    0],
        [   0,    0,    0,    2,    0]])
mask: tensor([[1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 0, 1, 1, 1],
        [1, 0, 1, 1, 0],
        [0, 0, 1, 1, 0],
        [0, 0, 1, 1, 0],
        [0, 0, 0, 1, 
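`zeroPadding` uses `itertools.zip_longest(*l)` as a transpose: a list of `batch_size` index lists becomes `max_length` rows of `batch_size` entries, padded with `PAD_token` (0). The batch is therefore time-major, shape `(max_length, batch_size)`, which is the default layout `nn.GRU` expects, and the mask marks the real tokens. A standalone sketch with made-up indexes (`zero_padding` and `binary_mask` mirror `zeroPadding` and `binaryMatrix`):

```python
import itertools

PAD_token = 0

def zero_padding(batch, fillvalue=PAD_token):
    # Transpose and pad: result[t][b] is token t of sequence b
    return list(itertools.zip_longest(*batch, fillvalue=fillvalue))

def binary_mask(padded):
    # 1 where a real token sits, 0 where padding was inserted
    return [[0 if tok == PAD_token else 1 for tok in row] for row in padded]

batch = [[25, 94, 117, 4, 2], [25, 200, 2]]  # sorted longest first; 2 = EOS
padded = zero_padding(batch)
print(padded)               # [(25, 25), (94, 200), (117, 2), (4, 0), (2, 0)]
print(binary_mask(padded))  # [[1, 1], [1, 1], [1, 1], [1, 0], [1, 0]]
```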

In [101]:
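## Sequence-to-sequence model
## Takes a variable-length input sequence and returns a variable-length output sequence using a fixed-size model

In [102]:
# The encoder turns each variable-length input into fixed-size representations that the decoder conditions its prediction on

In [103]:
## Encoder: a bidirectional GRU, i.e. two RNNs, one fed the input in sequential order and one fed it in reverse order

In [104]:
## INPUTS:  input_seq: batch of sentences as word indexes, shape = (max_length, batch_size)
##          input_lengths: length of each sentence in the batch, shape = (batch_size,)
##          hidden: initial hidden state, shape = (n_layers * num_directions, batch_size, hidden_size); zeros if None

In [105]:
## OUTPUTS: outputs: output features from the last GRU layer with both directions summed, shape = (max_length, batch_size, hidden_size)
##          hidden: updated hidden state, shape = (n_layers * num_directions, batch_size, hidden_size)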

In [106]:
class EncoderRNN(nn.Module):
    def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
        super(EncoderRNN, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embedding = embedding

        # Initialize GRU; the input_size and hidden_size params are both set to 'hidden_size'
        #   because our input size is a word embedding with number of features == hidden_size
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers,
                          dropout=(0 if n_layers == 1 else dropout), bidirectional=True)

    def forward(self, input_seq, input_lengths, hidden=None):
        # Convert word indexes to embeddings
        embedded = self.embedding(input_seq)
        # Pack padded batch of sequences for RNN module
        # (recent PyTorch versions require the lengths tensor to be on the CPU)
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths.cpu())
        # Forward pass through GRU
        outputs, hidden = self.gru(packed, hidden)
        # Unpack padding
        outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs)
        # Sum bidirectional GRU outputs
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:]
        # Return output and final hidden state
        return outputs, hidden
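To make the shapes concrete, here is a small standalone run of the same embed, pack, bidirectional GRU, unpack and sum sequence with tiny, arbitrary sizes (8 hidden units, a 20-word vocabulary, 2 layers). It also shows that padded time steps come back as zeros, so they add nothing to the summed outputs. `encoder_shapes_demo` is a hypothetical helper, wrapped in a function so it does not overwrite the notebook's globals:

```python
import torch
import torch.nn as nn

def encoder_shapes_demo(hidden_size=8, vocab_size=20, n_layers=2):
    embedding = nn.Embedding(vocab_size, hidden_size)
    gru = nn.GRU(hidden_size, hidden_size, n_layers, bidirectional=True)
    # Time-major batch of 3 sentences (columns), longest first; 2 = EOS, 0 = PAD
    input_seq = torch.tensor([[5, 6, 7],
                              [8, 9, 2],
                              [10, 2, 0],
                              [2, 0, 0]])
    lengths = torch.tensor([4, 3, 2])
    embedded = embedding(input_seq)                                      # (4, 3, 8)
    packed = nn.utils.rnn.pack_padded_sequence(embedded, lengths.cpu())
    outputs, hidden = gru(packed)
    outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs)               # (4, 3, 16)
    summed = outputs[:, :, :hidden_size] + outputs[:, :, hidden_size:]   # (4, 3, 8)
    return outputs, summed, hidden                                       # hidden: (4, 3, 8)

demo_outputs, demo_summed, demo_hidden = encoder_shapes_demo()
print(demo_outputs.shape, demo_summed.shape, demo_hidden.shape)
```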

In [107]:
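## The decoder generates the response one word at a time until it produces an EOS_token
## Attention mechanism: at each step the decoder weights every encoder output by its relevance to the current decoder state and uses the weighted sum (the context vector) as extra input, instead of relying only on the encoder's final hidden state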

In [108]:
## Attn computes Luong attention scores ('dot', 'general', or 'concat') between the decoder state and every encoder output, then normalizes them with a softmax

In [109]:
class Attn(torch.nn.Module):
    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()
        self.method = method
        if self.method not in ['dot', 'general', 'concat']:
            raise ValueError(self.method, "is not an appropriate attention method.")
        self.hidden_size = hidden_size
        if self.method == 'general':
            self.attn = torch.nn.Linear(self.hidden_size, hidden_size)
        elif self.method == 'concat':
            self.attn = torch.nn.Linear(self.hidden_size * 2, hidden_size)
            # torch.FloatTensor(n) allocates uninitialized memory, so initialize v explicitly
            self.v = torch.nn.Parameter(torch.rand(hidden_size))

    def dot_score(self, hidden, encoder_output):
        return torch.sum(hidden * encoder_output, dim=2)

    def general_score(self, hidden, encoder_output):
        energy = self.attn(encoder_output)
        return torch.sum(hidden * energy, dim=2)

    def concat_score(self, hidden, encoder_output):
        energy = self.attn(torch.cat((hidden.expand(encoder_output.size(0), -1, -1), encoder_output), 2)).tanh()
        return torch.sum(self.v * energy, dim=2)

    def forward(self, hidden, encoder_outputs):
        # Calculate the attention weights (energies) based on the given method
        if self.method == 'general':
            attn_energies = self.general_score(hidden, encoder_outputs)
        elif self.method == 'concat':
            attn_energies = self.concat_score(hidden, encoder_outputs)
        elif self.method == 'dot':
            attn_energies = self.dot_score(hidden, encoder_outputs)

        # Transpose max_length and batch_size dimensions
        attn_energies = attn_energies.t()

        # Return the softmax normalized probability scores (with added dimension)
        return F.softmax(attn_energies, dim=1).unsqueeze(1)
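With `attn_model = 'dot'`, the score for each encoder time step is the dot product between the current decoder GRU output and that step's encoder output; a softmax over time steps turns scores into weights, and the context vector is the weighted sum of encoder outputs (the `bmm` in the decoder). A plain-Python sketch for a single example with hypothetical 2-dimensional states (`dot_attention` is an illustrative helper, not the notebook's `Attn`):

```python
import math

def dot_attention(decoder_state, encoder_outputs):
    """Luong 'dot' attention for one example: returns (weights, context)."""
    scores = [sum(h * e for h, e in zip(decoder_state, enc)) for enc in encoder_outputs]
    m = max(scores)                          # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [x / total for x in exps]
    context = [sum(w * enc[d] for w, enc in zip(weights, encoder_outputs))
               for d in range(len(decoder_state))]
    return weights, context

weights, context = dot_attention([1.0, 0.0], [[2.0, 0.0], [0.0, 1.0], [2.0, 5.0]])
print(weights, context)
```

The first and third encoder steps score equally (2.0) and share most of the weight; the second step scores 0.0 and gets much less.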

In [110]:
class LuongAttnDecoderRNN(nn.Module):
    def __init__(self, attn_model, embedding, hidden_size, output_size, n_layers=1, dropout=0.1):
        super(LuongAttnDecoderRNN, self).__init__()

        # Keep for reference
        self.attn_model = attn_model
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = dropout

        # Define layers
        self.embedding = embedding
        self.embedding_dropout = nn.Dropout(dropout)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=(0 if n_layers == 1 else dropout))
        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

        self.attn = Attn(attn_model, hidden_size)

    def forward(self, input_step, last_hidden, encoder_outputs):
        # Note: we run this one step (word) at a time
        # Get embedding of current input word
        embedded = self.embedding(input_step)
        embedded = self.embedding_dropout(embedded)
        # Forward through unidirectional GRU
        rnn_output, hidden = self.gru(embedded, last_hidden)
        # Calculate attention weights from the current GRU output
        attn_weights = self.attn(rnn_output, encoder_outputs)
        # Multiply attention weights to encoder outputs to get new "weighted sum" context vector
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1))
        # Concatenate weighted context vector and GRU output using Luong eq. 5
        rnn_output = rnn_output.squeeze(0)
        context = context.squeeze(1)
        concat_input = torch.cat((rnn_output, context), 1)
        concat_output = torch.tanh(self.concat(concat_input))
        # Predict next word using Luong eq. 6
        output = self.out(concat_output)
        output = F.softmax(output, dim=1)
        # Return output and final hidden state
        return output, hidden

In [111]:
# Training: masked loss
# Targets are padded, so the negative log-likelihood is averaged only over positions where mask == 1

In [112]:
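def maskNLLLoss(inp, target, mask):
    # Number of real (non-PAD) target tokens at this time step
    nTotal = mask.sum()
    # -log of the probability the decoder assigned to each target word
    crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
    # masked_select needs a boolean mask in current PyTorch; outputVar builds a ByteTensor
    loss = crossEntropy.masked_select(mask.bool()).mean()
    loss = loss.to(device)
    return loss, nTotal.item()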
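`maskNLLLoss` picks out the probability the decoder assigned to each target word, takes `-log`, and averages only over positions where the mask is 1, so padding never contributes. It also returns the count of real tokens, which `train` uses to weight the reported loss. The same arithmetic in plain Python for one time step with a hypothetical batch of three (`masked_nll` is an illustrative helper and assumes at least one unmasked entry):

```python
import math

def masked_nll(probs, targets, mask):
    """Mean -log p(target) over unmasked batch entries, plus the number of entries used."""
    losses = [-math.log(p[t]) for p, t, m in zip(probs, targets, mask) if m]
    return sum(losses) / len(losses), len(losses)

probs = [[0.1, 0.7, 0.2],    # softmax output for example 0
         [0.5, 0.25, 0.25],  # example 1
         [0.9, 0.05, 0.05]]  # example 2: a padding position, ignored by the mask
toy_loss, n_tokens = masked_nll(probs, targets=[1, 0, 0], mask=[1, 1, 0])
print(round(toy_loss, 4), n_tokens)  # 0.5249 2
```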

Training procedure (one call to `train`):
1. Forward the entire input batch through the encoder.
2. Initialize the decoder input as SOS_token and the decoder hidden state as the encoder’s final hidden state.
3. Forward the batch through the decoder one time step at a time.
4. With teacher forcing, set the next decoder input to the current target; otherwise set it to the current decoder output.
5. Calculate and accumulate the loss.
6. Perform backpropagation.
7. Clip gradients.
8. Update the encoder and decoder model parameters.
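The two branches of step 4 differ only in what feeds the next decoder step. A minimal plain-Python sketch of that choice, where `toy_predict` is a hypothetical stand-in for the decoder's argmax and the coin flip against `teacher_forcing_ratio` happens once per batch, as in `train`. With the notebook's `teacher_forcing_ratio = 1.0`, every batch uses teacher forcing.

```python
import random

SOS_token = 1

def decoder_inputs(targets, predict, teacher_forcing_ratio, rng=random):
    """Return the token fed to the decoder at each time step."""
    use_teacher_forcing = rng.random() < teacher_forcing_ratio  # decided once per batch
    inputs, current = [], SOS_token
    for target in targets:
        inputs.append(current)
        prediction = predict(current)
        # Teacher forcing feeds the ground-truth token; otherwise the model's own guess
        current = target if use_teacher_forcing else prediction
    return inputs

def toy_predict(token):
    return token + 10  # toy stand-in for the decoder's most likely next token

print(decoder_inputs([5, 6, 2], toy_predict, teacher_forcing_ratio=1.0))  # [1, 5, 6]
print(decoder_inputs([5, 6, 2], toy_predict, teacher_forcing_ratio=0.0))  # [1, 11, 21]
```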

In [113]:
def train(input_variable, lengths, target_variable, mask, max_target_len, encoder, decoder, embedding,
          encoder_optimizer, decoder_optimizer, batch_size, clip, max_length=MAX_LENGTH):

    # Zero gradients
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    # Set device options
    input_variable = input_variable.to(device)
    lengths = lengths.to(device)
    target_variable = target_variable.to(device)
    mask = mask.to(device)

    # Initialize variables
    loss = 0
    print_losses = []
    n_totals = 0

    # Forward pass through encoder
    encoder_outputs, encoder_hidden = encoder(input_variable, lengths)

    # Create initial decoder input (start with SOS tokens for each sentence)
    decoder_input = torch.LongTensor([[SOS_token for _ in range(batch_size)]])
    decoder_input = decoder_input.to(device)

    # Set initial decoder hidden state to the encoder's final hidden state
    decoder_hidden = encoder_hidden[:decoder.n_layers]

    # Determine if we are using teacher forcing this iteration
    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False

    # Forward batch of sequences through decoder one time step at a time
    if use_teacher_forcing:
        for t in range(max_target_len):
            decoder_output, decoder_hidden = decoder(
                decoder_input, decoder_hidden, encoder_outputs
            )
            # Teacher forcing: next input is current target
            decoder_input = target_variable[t].view(1, -1)
            # Calculate and accumulate loss
            mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
            loss += mask_loss
            print_losses.append(mask_loss.item() * nTotal)
            n_totals += nTotal
    else:
        for t in range(max_target_len):
            decoder_output, decoder_hidden = decoder(
                decoder_input, decoder_hidden, encoder_outputs
            )
            # No teacher forcing: next input is decoder's own current output
            _, topi = decoder_output.topk(1)
            decoder_input = torch.LongTensor([[topi[i][0] for i in range(batch_size)]])
            decoder_input = decoder_input.to(device)
            # Calculate and accumulate loss
            mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
            loss += mask_loss
            print_losses.append(mask_loss.item() * nTotal)
            n_totals += nTotal

    # Perform backpropagation
    loss.backward()

    # Clip gradients: gradients are modified in place
    _ = torch.nn.utils.clip_grad_norm_(encoder.parameters(), clip)
    _ = torch.nn.utils.clip_grad_norm_(decoder.parameters(), clip)

    # Adjust model weights
    encoder_optimizer.step()
    decoder_optimizer.step()

    return sum(print_losses) / n_totals

In [114]:
def trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer, embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size, print_every, save_every, clip, corpus_name, loadFilename):

    # Load batches for each iteration
    training_batches = [batch2TrainData(voc, [random.choice(pairs) for _ in range(batch_size)]) for _ in range(n_iteration)]

    # Initializations
    print('Initializing ...')
    start_iteration = 1
    print_loss = 0
    if loadFilename:
        start_iteration = checkpoint['iteration'] + 1

    # Training loop
    print("Training...")
    for iteration in range(start_iteration, n_iteration + 1):
        training_batch = training_batches[iteration - 1]
        # Extract fields from batch
        input_variable, lengths, target_variable, mask, max_target_len = training_batch

        # Run a training iteration with batch
        loss = train(input_variable, lengths, target_variable, mask, max_target_len, encoder,
                     decoder, embedding, encoder_optimizer, decoder_optimizer, batch_size, clip)
        print_loss += loss

        # Print progress
        if iteration % print_every == 0:
            print_loss_avg = print_loss / print_every
            print("Iteration: {}; Percent complete: {:.1f}%; Average loss: {:.4f}".format(iteration, iteration / n_iteration * 100, print_loss_avg))
            print_loss = 0

        # Save checkpoint
        if (iteration % save_every == 0):
            directory = os.path.join(save_dir, model_name, corpus_name, '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, hidden_size))
            if not os.path.exists(directory):
                os.makedirs(directory)
            torch.save({
                'iteration': iteration,
                'en': encoder.state_dict(),
                'de': decoder.state_dict(),
                'en_opt': encoder_optimizer.state_dict(),
                'de_opt': decoder_optimizer.state_dict(),
                'loss': loss,
                'voc_dict': voc.__dict__,
                'embedding': embedding.state_dict()
            }, os.path.join(directory, '{}_{}.tar'.format(iteration, 'checkpoint')))

In [115]:
class GreedySearchDecoder(nn.Module):
    def __init__(self, encoder, decoder):
        super(GreedySearchDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, input_seq, input_length, max_length):
        # Forward input through encoder model
        encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)
        # Prepare encoder's final hidden layer to be first hidden input to the decoder
        decoder_hidden = encoder_hidden[:self.decoder.n_layers]
        # Initialize decoder input with SOS_token
        decoder_input = torch.ones(1, 1, device=device, dtype=torch.long) * SOS_token
        # Initialize tensors to append decoded words to
        all_tokens = torch.zeros([0], device=device, dtype=torch.long)
        all_scores = torch.zeros([0], device=device)
        # Iteratively decode one word token at a time
        for _ in range(max_length):
            # Forward pass through decoder
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
            # Obtain most likely word token and its softmax score
            decoder_scores, decoder_input = torch.max(decoder_output, dim=1)
            # Record token and score
            all_tokens = torch.cat((all_tokens, decoder_input), dim=0)
            all_scores = torch.cat((all_scores, decoder_scores), dim=0)
            # Prepare current token to be next decoder input (add a dimension)
            decoder_input = torch.unsqueeze(decoder_input, 0)
        # Return collections of word tokens and scores
        return all_tokens, all_scores
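`GreedySearchDecoder` picks the highest-probability token at every step and feeds it back in. It always runs for `max_length` steps and does not stop at EOS, which is why `evaluateInput` filters out `EOS` and `PAD` afterwards. A toy version in plain Python, with a hypothetical lookup table standing in for the decoder (`greedy_decode` and `table` are illustrative only):

```python
PAD_token, SOS_token, EOS_token = 0, 1, 2

def greedy_decode(step, max_length):
    """Feed back the argmax token for max_length steps; return tokens and their scores."""
    tokens, scores, current = [], [], SOS_token
    for _ in range(max_length):
        probs = step(current)
        best = max(range(len(probs)), key=probs.__getitem__)
        tokens.append(best)
        scores.append(probs[best])
        current = best
    return tokens, scores

# Hypothetical next-token distributions over a 5-word vocabulary, keyed by previous token
table = {1: [0.0, 0.0, 0.1, 0.6, 0.3], 3: [0.0, 0.0, 0.2, 0.1, 0.7],
         4: [0.0, 0.0, 0.9, 0.05, 0.05], 2: [1.0, 0.0, 0.0, 0.0, 0.0],
         0: [1.0, 0.0, 0.0, 0.0, 0.0]}
tokens, scores = greedy_decode(table.__getitem__, max_length=5)
words = [t for t in tokens if t not in (EOS_token, PAD_token)]
print(tokens, words)  # [3, 4, 2, 0, 0] [3, 4]
```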

In [116]:
def evaluate(encoder, decoder, searcher, voc, sentence, max_length=MAX_LENGTH):
    ### Format input sentence as a batch
    # words -> indexes
    indexes_batch = [indexesFromSentence(voc, sentence)]
    # Create lengths tensor
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    # Transpose dimensions of batch to match models' expectations
    input_batch = torch.LongTensor(indexes_batch).transpose(0, 1)
    # Use appropriate device
    input_batch = input_batch.to(device)
    lengths = lengths.to(device)
    # Decode sentence with searcher
    tokens, scores = searcher(input_batch, lengths, max_length)
    # indexes -> words
    decoded_words = [voc.index2word[token.item()] for token in tokens]
    return decoded_words

In [117]:
def evaluateInput(encoder, decoder, searcher, voc):
    input_sentence = ''
    while True:
        try:
            # Get input sentence
            input_sentence = input('> ')
            # Check if it is quit case
            if input_sentence == 'q' or input_sentence == 'quit': break
            # Normalize sentence
            input_sentence = normalizeString(input_sentence)
            # Evaluate sentence
            output_words = evaluate(encoder, decoder, searcher, voc, input_sentence)
            # Format and print response sentence
            output_words[:] = [x for x in output_words if not (x == 'EOS' or x == 'PAD')]
            print('Bot:', ' '.join(output_words))

        except KeyError:
            print("Error: Encountered unknown word.")

In [118]:
## Run Model ##

In [119]:
# Configure model
model_name = 'cb_model'
attn_model = 'dot'
hidden_size = 500
encoder_n_layers = 2
decoder_n_layers = 2
dropout = 0.1
batch_size = 64

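# Set loadFilename to a saved .tar checkpoint path to resume from it; None trains from scratch
loadFilename = None
checkpoint_iter = 4000

# Load the checkpoint if a file name is provided
if loadFilename:
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine
    checkpoint = torch.load(loadFilename, map_location=device)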
    encoder_sd = checkpoint['en']
    decoder_sd = checkpoint['de']
    encoder_optimizer_sd = checkpoint['en_opt']
    decoder_optimizer_sd = checkpoint['de_opt']
    embedding_sd = checkpoint['embedding']
    voc.__dict__ = checkpoint['voc_dict']
    
print('Building encoder and decoder ...')

embedding = nn.Embedding(voc.num_words, hidden_size)
if loadFilename:
    embedding.load_state_dict(embedding_sd)
# Initialize encoder & decoder models
encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.num_words, decoder_n_layers, dropout)
if loadFilename:
    encoder.load_state_dict(encoder_sd)
    decoder.load_state_dict(decoder_sd)
# Use appropriate device
encoder = encoder.to(device)
decoder = decoder.to(device)
print('Models built and ready to go!')

Building encoder and decoder ...
Models built and ready to go!


In [120]:
# Run the training 

In [None]:
clip = 50.0
teacher_forcing_ratio = 1.0
learning_rate = 0.0001
decoder_learning_ratio = 5.0
n_iteration = 4000
print_every = 1
save_every = 500

encoder.train()
decoder.train()

# initialize optimizers
print("Building Optimizers... ")
encoder_optimizer = optim.Adam(encoder.parameters(), lr = learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(), lr = learning_rate * decoder_learning_ratio)
if loadFilename:
    encoder_optimizer.load_state_dict(encoder_optimizer_sd)
    decoder_optimizer.load_state_dict(decoder_optimizer_sd)
    
# Run training iterations
print("Starting Training... ")
trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer,
           embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size,
           print_every, save_every, clip, corpus_name, loadFilename)

Building Optimizers... 
Starting Training... 
Initializing ...
Training...
Iteration: 1; Percent complete: 0.0%; Average loss: 8.9645
Iteration: 2; Percent complete: 0.1%; Average loss: 8.8450
Iteration: 3; Percent complete: 0.1%; Average loss: 8.6219
Iteration: 4; Percent complete: 0.1%; Average loss: 8.3246
Iteration: 5; Percent complete: 0.1%; Average loss: 7.9387
Iteration: 6; Percent complete: 0.1%; Average loss: 7.3928
Iteration: 7; Percent complete: 0.2%; Average loss: 6.7708
Iteration: 8; Percent complete: 0.2%; Average loss: 6.7035
Iteration: 9; Percent complete: 0.2%; Average loss: 6.9538
Iteration: 10; Percent complete: 0.2%; Average loss: 6.6883
Iteration: 11; Percent complete: 0.3%; Average loss: 5.9951
Iteration: 12; Percent complete: 0.3%; Average loss: 5.9433
Iteration: 13; Percent complete: 0.3%; Average loss: 5.8848
Iteration: 14; Percent complete: 0.4%; Average loss: 5.5489
Iteration: 15; Percent complete: 0.4%; Average loss: 5.5101
...
Iteration: 136; Percent complete: 3.4%; Average loss: 4.4973
Iteration: 137; Percent complete: 3.4%; Average loss: 4.2580
Iteration: 138; Percent complete: 3.5%; Average loss: 4.2419
Iteration: 139; Percent complete: 3.5%; Average loss: 4.1750
Iteration: 140; Percent complete: 3.5%; Average loss: 3.9232
Iteration: 141; Percent complete: 3.5%; Average loss: 4.3293
Iteration: 142; Percent complete: 3.5%; Average loss: 4.1685
Iteration: 143; Percent complete: 3.6%; Average loss: 4.1015
Iteration: 144; Percent complete: 3.6%; Average loss: 4.1490
Iteration: 145; Percent complete: 3.6%; Average loss: 4.1190
Iteration: 146; Percent complete: 3.6%; Average loss: 4.1611
Iteration: 147; Percent complete: 3.7%; Average loss: 4.3995
Iteration: 148; Percent complete: 3.7%; Average loss: 4.1336
Iteration: 149; Percent complete: 3.7%; Average loss: 4.3831
Iteration: 150; Percent complete: 3.8%; Average loss: 4.2324
Iteration: 151; Percent complete: 3.8%; Average loss: 4.3074
...
Iteration: 271; Percent complete: 6.8%; Average loss: 3.9438
Iteration: 272; Percent complete: 6.8%; Average loss: 3.9430
Iteration: 273; Percent complete: 6.8%; Average loss: 3.8690
Iteration: 274; Percent complete: 6.9%; Average loss: 4.2129
Iteration: 275; Percent complete: 6.9%; Average loss: 4.0096
Iteration: 276; Percent complete: 6.9%; Average loss: 3.9345
Iteration: 277; Percent complete: 6.9%; Average loss: 3.9623
Iteration: 278; Percent complete: 7.0%; Average loss: 3.8221
Iteration: 279; Percent complete: 7.0%; Average loss: 3.9287
Iteration: 280; Percent complete: 7.0%; Average loss: 3.7251
Iteration: 281; Percent complete: 7.0%; Average loss: 3.9734
Iteration: 282; Percent complete: 7.0%; Average loss: 3.8388
Iteration: 283; Percent complete: 7.1%; Average loss: 3.8065
Iteration: 284; Percent complete: 7.1%; Average loss: 4.1367
Iteration: 285; Percent complete: 7.1%; Average loss: 3.7766
Iteration: 286; Percent complete: 7.1%; Average loss: 3.8187
...
Iteration: 406; Percent complete: 10.2%; Average loss: 3.6729
Iteration: 407; Percent complete: 10.2%; Average loss: 3.7606
Iteration: 408; Percent complete: 10.2%; Average loss: 3.6380
Iteration: 409; Percent complete: 10.2%; Average loss: 3.5041
Iteration: 410; Percent complete: 10.2%; Average loss: 3.5870
Iteration: 411; Percent complete: 10.3%; Average loss: 3.7007
Iteration: 412; Percent complete: 10.3%; Average loss: 3.9809
Iteration: 413; Percent complete: 10.3%; Average loss: 3.9060
Iteration: 414; Percent complete: 10.3%; Average loss: 3.9203
Iteration: 415; Percent complete: 10.4%; Average loss: 3.7218
Iteration: 416; Percent complete: 10.4%; Average loss: 3.7656
Iteration: 417; Percent complete: 10.4%; Average loss: 3.6221
Iteration: 418; Percent complete: 10.4%; Average loss: 3.8230
Iteration: 419; Percent complete: 10.5%; Average loss: 3.6876
Iteration: 420; Percent complete: 10.5%; Average loss: 3.5350
Iteration: 421; Percent complete: 10.5%; Average loss: 3.7141
...
Iteration: 539; Percent complete: 13.5%; Average loss: 3.6762
Iteration: 540; Percent complete: 13.5%; Average loss: 3.7923
Iteration: 541; Percent complete: 13.5%; Average loss: 3.6626
Iteration: 542; Percent complete: 13.6%; Average loss: 3.7521
Iteration: 543; Percent complete: 13.6%; Average loss: 3.7243
Iteration: 544; Percent complete: 13.6%; Average loss: 3.9216
Iteration: 545; Percent complete: 13.6%; Average loss: 3.5903
Iteration: 546; Percent complete: 13.7%; Average loss: 3.7235
Iteration: 547; Percent complete: 13.7%; Average loss: 3.6219
Iteration: 548; Percent complete: 13.7%; Average loss: 3.7909
Iteration: 549; Percent complete: 13.7%; Average loss: 3.8004
Iteration: 550; Percent complete: 13.8%; Average loss: 3.6590
Iteration: 551; Percent complete: 13.8%; Average loss: 3.7380
Iteration: 552; Percent complete: 13.8%; Average loss: 3.7032
Iteration: 553; Percent complete: 13.8%; Average loss: 3.8119
Iteration: 554; Percent complete: 13.9%; Average loss: 3.8219
...
Iteration: 672; Percent complete: 16.8%; Average loss: 3.8773
Iteration: 673; Percent complete: 16.8%; Average loss: 3.5432
Iteration: 674; Percent complete: 16.9%; Average loss: 3.7267
Iteration: 675; Percent complete: 16.9%; Average loss: 3.4267
Iteration: 676; Percent complete: 16.9%; Average loss: 3.4197
Iteration: 677; Percent complete: 16.9%; Average loss: 3.6546
Iteration: 678; Percent complete: 17.0%; Average loss: 3.5094
Iteration: 679; Percent complete: 17.0%; Average loss: 3.7377
Iteration: 680; Percent complete: 17.0%; Average loss: 3.7665
Iteration: 681; Percent complete: 17.0%; Average loss: 3.6848
Iteration: 682; Percent complete: 17.1%; Average loss: 3.7678
Iteration: 683; Percent complete: 17.1%; Average loss: 3.4809
Iteration: 684; Percent complete: 17.1%; Average loss: 3.6471
Iteration: 685; Percent complete: 17.1%; Average loss: 3.7088
Iteration: 686; Percent complete: 17.2%; Average loss: 3.6506
Iteration: 687; Percent complete: 17.2%; Average loss: 3.5096
...
Iteration: 805; Percent complete: 20.1%; Average loss: 3.7106
Iteration: 806; Percent complete: 20.2%; Average loss: 3.3394
Iteration: 807; Percent complete: 20.2%; Average loss: 3.4323
Iteration: 808; Percent complete: 20.2%; Average loss: 3.7448
Iteration: 809; Percent complete: 20.2%; Average loss: 3.5854
Iteration: 810; Percent complete: 20.2%; Average loss: 3.3565
Iteration: 811; Percent complete: 20.3%; Average loss: 3.3611
Iteration: 812; Percent complete: 20.3%; Average loss: 3.4395
Iteration: 813; Percent complete: 20.3%; Average loss: 3.6045
Iteration: 814; Percent complete: 20.3%; Average loss: 3.7111
Iteration: 815; Percent complete: 20.4%; Average loss: 3.7452
Iteration: 816; Percent complete: 20.4%; Average loss: 3.8972
Iteration: 817; Percent complete: 20.4%; Average loss: 3.6490
Iteration: 818; Percent complete: 20.4%; Average loss: 3.6857
Iteration: 819; Percent complete: 20.5%; Average loss: 3.6907
Iteration: 820; Percent complete: 20.5%; Average loss: 3.4038
...
Iteration: 938; Percent complete: 23.4%; Average loss: 3.6514
Iteration: 939; Percent complete: 23.5%; Average loss: 3.4505
Iteration: 940; Percent complete: 23.5%; Average loss: 3.5475
Iteration: 941; Percent complete: 23.5%; Average loss: 3.4966
Iteration: 942; Percent complete: 23.5%; Average loss: 3.5979
Iteration: 943; Percent complete: 23.6%; Average loss: 3.5034
Iteration: 944; Percent complete: 23.6%; Average loss: 3.3432
Iteration: 945; Percent complete: 23.6%; Average loss: 3.7729
Iteration: 946; Percent complete: 23.6%; Average loss: 3.4787
Iteration: 947; Percent complete: 23.7%; Average loss: 3.3357
Iteration: 948; Percent complete: 23.7%; Average loss: 3.4095
Iteration: 949; Percent complete: 23.7%; Average loss: 3.5556
Iteration: 950; Percent complete: 23.8%; Average loss: 3.3919
Iteration: 951; Percent complete: 23.8%; Average loss: 3.6416
Iteration: 952; Percent complete: 23.8%; Average loss: 3.4962
Iteration: 953; Percent complete: 23.8%; Average loss: 3.3142
...
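
Because `print_every = 1`, each "Average loss" above is the loss of a single mini-batch, so consecutive values jump around (iteration 139 logs 4.1750, iteration 140 logs 3.9232). A minimal sketch for reading the trend is to parse the log and average each run of consecutive iterations. The regex and the helper names `parse_log` and `summarize_runs` are illustrative. Assuming the logged value is a mean per-token negative log-likelihood in nats, `math.exp` of it gives an approximate per-token perplexity.

```python
import math
import re
from statistics import mean

# Matches complete lines such as
# "Iteration: 139; Percent complete: 3.5%; Average loss: 4.1750".
# Truncated lines like "Iteration: 152; Percent " do not match and are skipped.
LOG_LINE = re.compile(
    r"Iteration: (\d+); Percent complete: [\d.]+%; Average loss: ([\d.]+)"
)


def parse_log(text):
    """Return a list of (iteration, loss) pairs from the training log text."""
    return [(int(m.group(1)), float(m.group(2))) for m in LOG_LINE.finditer(text)]


def summarize_runs(records):
    """Group consecutive iterations into runs and return, for each run,
    (first_iteration, last_iteration, mean_loss, exp(mean_loss))."""
    runs = []
    for iteration, loss in records:
        if runs and iteration == runs[-1][-1][0] + 1:
            runs[-1].append((iteration, loss))
        else:
            runs.append([(iteration, loss)])
    summary = []
    for run in runs:
        avg = mean(loss for _, loss in run)
        summary.append((run[0][0], run[-1][0], avg, math.exp(avg)))
    return summary


# A few lines copied from the log above, including one truncated line.
sample = """\
Iteration: 1; Percent complete: 0.0%; Average loss: 8.9645
Iteration: 2; Percent complete: 0.1%; Average loss: 8.8450
Iteration: 3; Percent complete: 0.1%; Average loss: 8.6219
Iteration: 4; Percent complete: 0.1%; Average loss: 8.3246
Iteration: 16; Percent complete: 0
Iteration: 950; Percent complete: 23.8%; Average loss: 3.3919
Iteration: 951; Percent complete: 23.8%; Average loss: 3.6416
Iteration: 952; Percent complete: 23.8%; Average loss: 3.4962
Iteration: 953; Percent complete: 23.8%; Average loss: 3.3142
"""

for first, last, avg, ppl in summarize_runs(parse_log(sample)):
    print(f"iterations {first}-{last}: mean loss {avg:.4f}, exp(loss) {ppl:.0f}")
# prints:
# iterations 1-4: mean loss 8.6890, exp(loss) 5937
# iterations 950-953: mean loss 3.4610, exp(loss) 32
```

Averaging over longer windows (for example, every 100 iterations) gives a smoother curve than the per-batch values printed during training.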