## Preparations

In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import torch
from torch.jit import script, trace
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import csv
import random
import re
import os
import unicodedata
import codecs
from io import open
import itertools
import math


USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")

## Load & Preprocess Data 

Let's have a look to our dialydialog dataset : https://www.aclweb.org/anthology/I17-1099/

In [2]:
corpus_name = "dailydialog"
corpus = os.path.join("data", corpus_name)

def printLines(file, n=10):
    with open(file, 'r', encoding="utf-8") as datafile:
        lines = datafile.readlines()
    for line in lines[:n]:
        print(line)

printLines(os.path.join(corpus, "dialogues_text.txt"))

The kitchen stinks . __eou__ I'll throw out the garbage . __eou__

So Dick , how about getting some coffee for tonight ? __eou__ Coffee ? I don ’ t honestly like that kind of stuff . __eou__ Come on , you can at least try a little , besides your cigarette . __eou__ What ’ s wrong with that ? Cigarette is the thing I go crazy for . __eou__ Not for me , Dick . __eou__

Are things still going badly with your houseguest ? __eou__ Getting worse . Now he ’ s eating me out of house and home . I ’ Ve tried talking to him but it all goes in one ear and out the other . He makes himself at home , which is fine . But what really gets me is that yesterday he walked into the living room in the raw and I had company over ! That was the last straw . __eou__ Leo , I really think you ’ re beating around the bush with this guy . I know he used to be your best friend in college , but I really think it ’ s time to lay down the law . __eou__ You ’ re right . Everything is probably going to come to a head to

## Create formatted data file

We'll create a formatted data file in which each line contains a tab-separated query sentence and a response sentence pair.

The following functions pase the `dailogues_text.txt` data file.
- `loadLines` splits each line of the file into conversations
- `extractSentencePairs` extracts pairs of sentences from conversations

In [3]:
# Splits each line of the file into a dictionary of fields
def loadLines(fileName):
    conversations = []
    with open(fileName, 'r', encoding='utf-8') as f:
        for line in f:
            values = line.split("__eou__")
            conversations.append(values)
    return conversations


# Extracts pairs of sentences from conversations
def extractSentencePairs(conversations):
    qa_pairs = []
    for conversation in conversations:
        # Iterate over all the lines of the conversation
        for i in range(len(conversation) - 1):  # We ignore the last line (no answer for it)
            inputLine = conversation[i].strip()
            targetLine = conversation[i+1].strip()
            # Filter wrong samples (if one of the lists is empty)
            if inputLine and targetLine:
                qa_pairs.append([inputLine, targetLine])
    return qa_pairs

Now we call the above functions to create a new file : `formatted_dialogues_text.txt`

In [4]:
# Define path to new file
datafile = os.path.join(corpus, "formatted_dialogues_text.txt")

delimiter = '\t'
# Unescape the delimiter
delimiter = str(codecs.decode(delimiter, "unicode_escape"))

print("\nLoading conversations...")
conversations = loadLines(os.path.join(corpus, "dialogues_text.txt"))

# Write new csv file
print("\nWriting newly formatted file...")
with open(datafile, 'w', encoding='utf-8') as outputfile:
    writer = csv.writer(outputfile, delimiter=delimiter, lineterminator='\n')
    for pair in extractSentencePairs(conversations):
        writer.writerow(pair)

# Print a sample of lines
print("\nSample lines from file:")
printLines(datafile)


Loading conversations...

Writing newly formatted file...

Sample lines from file:
The kitchen stinks .	I'll throw out the garbage .

So Dick , how about getting some coffee for tonight ?	Coffee ? I don ’ t honestly like that kind of stuff .

Coffee ? I don ’ t honestly like that kind of stuff .	Come on , you can at least try a little , besides your cigarette .

Come on , you can at least try a little , besides your cigarette .	What ’ s wrong with that ? Cigarette is the thing I go crazy for .

What ’ s wrong with that ? Cigarette is the thing I go crazy for .	Not for me , Dick .

Are things still going badly with your houseguest ?	Getting worse . Now he ’ s eating me out of house and home . I ’ Ve tried talking to him but it all goes in one ear and out the other . He makes himself at home , which is fine . But what really gets me is that yesterday he walked into the living room in the raw and I had company over ! That was the last straw .

Getting worse . Now he ’ s eating me out of 

## Load and trim data

Now let's crzate a vocabulary and load query/response sentence pairs into memory.

First we must create a mapping of each word to a discrete numerical space (the index value).

Voc class keeps the mapping from words to indexes, a reverse mapping of indexes to words, a count of each word and a total word count.
There are 3 central methods :
- `addWord` to add a word to the vocabulary
- `addSentence` to add all words in a sentence
- `trim` for trimming infrequently seen words

In [5]:
# Default word tokens
PAD_token = 0  # Used for padding short sentences
SOS_token = 1  # Start-of-sentence token
EOS_token = 2  # End-of-sentence token

class Voc:
    def __init__(self, name):
        self.name = name
        self.trimmed = False
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3  # Count SOS, EOS, PAD

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1
        else:
            self.word2count[word] += 1

    # Remove words below a certain count threshold
    def trim(self, min_count):
        if self.trimmed:
            return
        self.trimmed = True

        keep_words = []

        for k, v in self.word2count.items():
            if v >= min_count:
                keep_words.append(k)

        print('keep_words {} / {} = {:.4f}'.format(
            len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index)
        ))

        # Reinitialize dictionaries
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3 # Count default tokens

        for word in keep_words:
            self.addWord(word)

Before assemble our vocabulary and query/response sentence pairs we must perform some preprocessing.

1. Convert the Unicode strings to ASCII with `unicodeToAscii`.
2. Convert all letters to lowercase and trim all non-letter characters except basic punctuation `normalizeString`
3. Filter out sentences witg length greater than the `MAX_LENGTH` threshold in `filterPairs`


In [6]:
MAX_LENGTH = 15  # Maximum sentence length to consider

# Turn a Unicode string to plain ASCII, thanks to
# https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

# Lowercase, trim, and remove non-letter characters
def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    s = re.sub(r"\s+", r" ", s).strip()
    return s

# Read query/response pairs and return a voc object
def readVocs(datafile, corpus_name):
    print("Reading lines...")
    # Read the file and split into lines
    lines = open(datafile, encoding='utf-8').\
        read().strip().split('\n')
    # Split every line into pairs and normalize
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]
    voc = Voc(corpus_name)
    return voc, pairs

# Returns True iff both sentences in a pair 'p' are under the MAX_LENGTH threshold
def filterPair(p):
    # Input sequences need to preserve the last word for EOS token
    return len(p[0].split(' ')) < MAX_LENGTH and len(p[1].split(' ')) < MAX_LENGTH

# Filter pairs using filterPair condition
def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]

# Using the functions defined above, return a populated voc object and pairs list
def loadPrepareData(corpus, corpus_name, datafile, save_dir):
    print("Start preparing training data ...")
    voc, pairs = readVocs(datafile, corpus_name)
    print("Read {!s} sentence pairs".format(len(pairs)))
    pairs = filterPairs(pairs)
    print("Trimmed to {!s} sentence pairs".format(len(pairs)))
    print("Counting words...")
    for pair in pairs:
        voc.addSentence(pair[0])
        voc.addSentence(pair[1])
    print("Counted words:", voc.num_words)
    return voc, pairs

Finally assmble voc and pairs

In [7]:
# Load/Assemble voc and pairs
save_dir = os.path.join("data", "save")
voc, pairs = loadPrepareData(corpus, corpus_name, datafile, save_dir)
# Print some pairs to validate
print("\npairs:")
for pair in pairs[:10]:
    print(pair)

Start preparing training data ...
Reading lines...
Read 89862 sentence pairs
Trimmed to 43724 sentence pairs
Counting words...
Counted words: 10119

pairs:
['the kitchen stinks .', 'i ll throw out the garbage .']
['so dick how about getting some coffee for tonight ?', 'coffee ? i don t honestly like that kind of stuff .']
['coffee ? i don t honestly like that kind of stuff .', 'come on you can at least try a little besides your cigarette .']
['would you mind waiting a while ?', 'well how long will it be ?']
['i swear i m going to kill you for this .', 'what s wrong ? didn t you think it was fun ? !']
['never ! but thank you for inviting me .', 'come on . you ll feel better after we hit the showers .']
['certainly . how about spaghetti with clams and shrimps .', 'sounds delicious . ok . she ll try that .']
['can you manage chopsticks ?', 'why not ? see .']
['why not ? see .', 'good mastery . how do you like our chinese food ?']
['i m exhausted .', 'okay let s go home .']


### Trimming rarely used words out of vocab

One tactic beneficial to achieve faster convergence during training is trimming rarely used words out of our vocabulary.

1. Trim words used under `MIN_COUNT` threshold using `voc.trim`
2. Filter out pairs with trimmed words

In [8]:
MIN_COUNT = 3    # Minimum word count threshold for trimming

def trimRareWords(voc, pairs, MIN_COUNT):
    # Trim words used under the MIN_COUNT from the voc
    voc.trim(MIN_COUNT)
    # Filter out pairs with trimmed words
    keep_pairs = []
    for pair in pairs:
        input_sentence = pair[0]
        output_sentence = pair[1]
        keep_input = True
        keep_output = True
        # Check input sentence
        for word in input_sentence.split(' '):
            if word not in voc.word2index:
                keep_input = False
                break
        # Check output sentence
        for word in output_sentence.split(' '):
            if word not in voc.word2index:
                keep_output = False
                break

        # Only keep pairs that do not contain trimmed word(s) in their input or output sentence
        if keep_input and keep_output:
            keep_pairs.append(pair)

    print("Trimmed from {} pairs to {}, {:.4f} of total".format(len(pairs), len(keep_pairs), len(keep_pairs) / len(pairs)))
    return keep_pairs


# Trim voc and pairs
pairs = trimRareWords(voc, pairs, MIN_COUNT)

keep_words 6157 / 10116 = 0.6086
Trimmed from 43724 pairs to 38716, 0.8855 of total


## Prepare Data for Models

BATCH TECHNIQUE

To accomodate sentences of different sizes in the same batch we make our batched input tensor of shape `(max_length, batch_size)` where sentences shorter than the max_length are zeropadded after the `EOS_token`.

- `inputVar` function handles the process of converting sentences to tensor. It returns a tensor of `lengths` for each sequence in the batch for the decoder.
- `outputVar` function performs the same as `inputVar` but returns a binary mask tensor and a maximum target sentence length.
- `batch2TrainData` takes a bunch of pairs and returns the input and target tensors

In [9]:
def indexesFromSentence(voc, sentence):
    return [voc.word2index[word] for word in sentence.split(' ')] + [EOS_token]


def zeroPadding(l, fillvalue=PAD_token):
    return list(itertools.zip_longest(*l, fillvalue=fillvalue))

def binaryMatrix(l, value=PAD_token):
    m = []
    for i, seq in enumerate(l):
        m.append([])
        for token in seq:
            if token == PAD_token:
                m[i].append(0)
            else:
                m[i].append(1)
    return m

# Returns padded input sequence tensor and lengths
def inputVar(l, voc):
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    padVar = torch.LongTensor(padList)
    return padVar, lengths

# Returns padded target sequence tensor, padding mask, and max target length
def outputVar(l, voc):
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    max_target_len = max([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    mask = binaryMatrix(padList)
    mask = torch.BoolTensor(mask)
    padVar = torch.LongTensor(padList)
    return padVar, mask, max_target_len

# Returns all items for a given batch of pairs
def batch2TrainData(voc, pair_batch):
    pair_batch.sort(key=lambda x: len(x[0].split(" ")), reverse=True)
    input_batch, output_batch = [], []
    for pair in pair_batch:
        input_batch.append(pair[0])
        output_batch.append(pair[1])
    inp, lengths = inputVar(input_batch, voc)
    output, mask, max_target_len = outputVar(output_batch, voc)
    return inp, lengths, output, mask, max_target_len


# Example for validation
small_batch_size = 5
batches = batch2TrainData(voc, [random.choice(pairs) for _ in range(small_batch_size)])
input_variable, lengths, target_variable, mask, max_target_len = batches

print("input_variable:", input_variable)
print("lengths:", lengths)
print("target_variable:", target_variable)
print("mask:", mask)
print("max_target_len:", max_target_len)

input_variable: tensor([[ 140,   47,  161,   86,  103],
        [  31,   31,  101,   90,    6],
        [ 504,   49,  699,   31,  266],
        [  18, 1829,   27,  268,    3],
        [ 536,  152, 4040,   25, 1737],
        [ 119, 1322,   93,   20,    5],
        [ 347,  251,   33,    2,    2],
        [2058,  585,  449,    0,    0],
        [ 485, 1236, 1610,    0,    0],
        [ 134, 1237,    5,    0,    0],
        [  49,   20,    2,    0,    0],
        [  20,    2,    0,    0,    0],
        [   2,    0,    0,    0,    0]])
lengths: tensor([13, 12, 11,  7,  7])
target_variable: tensor([[   6,    6,    6,   45,   91],
        [ 134,  266,  580, 1630,  446],
        [  53,   53,   22,  114,   19],
        [ 587,  876,  144, 5429,  154],
        [ 119,   36,  647,   78, 2518],
        [   3, 1236,  143,   29,    5],
        [1113, 1237,   65,   30,    2],
        [ 552,    5,  418,  428,    0],
        [  27,   90,    6,    5,    0],
        [ 141,   31,    7,    2,    0],
        

## Seq2Seq

The brain of our chatbot is a sequence to sequence model.

One RNN acts as an _encoder_  which encodes a variable length input sequence to a fixed-length context vector (the final hidden layer of the RNN).
The second RNN is a _decoder_ which takes a s input a word and a context vector and returns a guess for the next word in the sequence.

 ### Encoder

The encoder RNN iterates through tokens and outputs an "output" vector and a "hidden sate" vector. the hidden state vector is passed to the next time step while the output vector is recorder.
The encoder transforms the context it saw at each point in the sequence into a set of points in a high dimensional space. The decoder will use it to generate the outputted word.

We use a bidirectional multi-layered Gated Recurrent Unit.
It gives the advantage of encoding both past and future context !

Computation Graph :
1. Convert word indexed to embeddings
2. Pack padded batch of sequences for RNN module
3. Forward pass through GRU
4. Unpack padding
5. Sum bidirectional GRU outputs
6. Return output and final hidden state

In [10]:
class EncoderRNN(nn.Module):
    def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
        super(EncoderRNN, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embedding = embedding

        # Initialize GRU; the input_size and hidden_size params are both set to 'hidden_size'
        #   because our input size is a word embedding with number of features == hidden_size
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers,
                          dropout=(0 if n_layers == 1 else dropout), bidirectional=True)

    def forward(self, input_seq, input_lengths, hidden=None):
        # Convert word indexes to embeddings
        embedded = self.embedding(input_seq)
        # Pack padded batch of sequences for RNN module
        packed = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
        # Forward pass through GRU
        outputs, hidden = self.gru(packed, hidden)
        # Unpack padding
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs)
        # Sum bidirectional GRU outputs
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, : ,self.hidden_size:]
        # Return output and final hidden state
        return outputs, hidden

### Decoder

The decoder RNN uses the encoder's context vectors and internal hidden states to generate the next word of the sequence.
It continues generating words until an `EOS_token`.
The problem with a vanilla seq2seq decoder is that if we rely woley on the context vector it will have information loss. (especially with long input sequences).

-> `attention mechanism` allows the decoder to pay attention to certain parts of the input sequence.

In [11]:
# Luong attention layer
class Attn(nn.Module):
    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()
        self.method = method
        if self.method not in ['dot', 'general', 'concat']:
            raise ValueError(self.method, "is not an appropriate attention method.")
        self.hidden_size = hidden_size
        if self.method == 'general':
            self.attn = nn.Linear(self.hidden_size, hidden_size)
        elif self.method == 'concat':
            self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
            self.v = nn.Parameter(torch.FloatTensor(hidden_size))

    def dot_score(self, hidden, encoder_output):
        return torch.sum(hidden * encoder_output, dim=2)

    def general_score(self, hidden, encoder_output):
        energy = self.attn(encoder_output)
        return torch.sum(hidden * energy, dim=2)

    def concat_score(self, hidden, encoder_output):
        energy = self.attn(torch.cat((hidden.expand(encoder_output.size(0), -1, -1), encoder_output), 2)).tanh()
        return torch.sum(self.v * energy, dim=2)

    def forward(self, hidden, encoder_outputs):
        # Calculate the attention weights (energies) based on the given method
        if self.method == 'general':
            attn_energies = self.general_score(hidden, encoder_outputs)
        elif self.method == 'concat':
            attn_energies = self.concat_score(hidden, encoder_outputs)
        elif self.method == 'dot':
            attn_energies = self.dot_score(hidden, encoder_outputs)

        # Transpose max_length and batch_size dimensions
        attn_energies = attn_energies.t()

        # Return the softmax normalized probability scores (with added dimension)
        return F.softmax(attn_energies, dim=1).unsqueeze(1)

Now that the attention submodule is implemented let's dive into the actual decoder model.

Computation Graph:
1. Get embedding of current input word.
2. Forward through unidirectional GRU
3. Calculate attention weights from the current GRU output
4. Multiply attention weights to encoder outputs to get a new context vector
5. Concatenate weighted context vector and GRU using Luong eq
6. Predict next word
7. Return output and final hidden state

In [12]:
class LuongAttnDecoderRNN(nn.Module):
    def __init__(self, attn_model, embedding, hidden_size, output_size, n_layers=1, dropout=0.1):
        super(LuongAttnDecoderRNN, self).__init__()

        # Keep for reference
        self.attn_model = attn_model
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = dropout

        # Define layers
        self.embedding = embedding
        self.embedding_dropout = nn.Dropout(dropout)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=(0 if n_layers == 1 else dropout))
        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

        self.attn = Attn(attn_model, hidden_size)

    def forward(self, input_step, last_hidden, encoder_outputs):
        # Note: we run this one step (word) at a time
        # Get embedding of current input word
        embedded = self.embedding(input_step)
        embedded = self.embedding_dropout(embedded)
        # Forward through unidirectional GRU
        rnn_output, hidden = self.gru(embedded, last_hidden)
        # Calculate attention weights from the current GRU output
        attn_weights = self.attn(rnn_output, encoder_outputs)
        # Multiply attention weights to encoder outputs to get new "weighted sum" context vector
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1))
        # Concatenate weighted context vector and GRU output using Luong eq. 5
        rnn_output = rnn_output.squeeze(0)
        context = context.squeeze(1)
        concat_input = torch.cat((rnn_output, context), 1)
        concat_output = torch.tanh(self.concat(concat_input))
        # Predict next word using Luong eq. 6
        output = self.out(concat_output)
        output = F.softmax(output, dim=1)
        # Return output and final hidden state
        return output, hidden

## Training 

`maskNNLLLoss` calculates the average negative log likelihood of the elements that correspond to a 1 in the mask tensor.

In [13]:
def maskNLLLoss(inp, target, mask):
    nTotal = mask.sum()
    crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
    loss = crossEntropy.masked_select(mask).mean()
    loss = loss.to(device)
    return loss, nTotal.item()

Single training iteration (single batch of inputs)

Couple of clever tricks :
- `teacher forcing` at some probability (set by `teacher_forcing_ratio`) current target word is used as the decoder's next input rather than using the decoder's current guess.
- `gradient clipping` commonly used technique for countering the "exploding gradient" problem.

Sequence of operations :
1. Forward pass entire input batch through encoder
2. initialize decoder inputs as SOS_token and hidden state as the encoder's final hidden state
3. Forward input batch sequence through decoder one time step at a time
4. If teacher forcing : set next decoder input as the current target else : set next decoder input as the current decoder ouptput
5. Calculate and accumulate loss
6. Perform backpropagation
7. Clip gradients
8. Update encoder and decoder

In [14]:
def train(input_variable, lengths, target_variable, mask, max_target_len, encoder, decoder, embedding,
          encoder_optimizer, decoder_optimizer, batch_size, clip, max_length=MAX_LENGTH):

    # Zero gradients
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    # Set device options
    input_variable = input_variable.to(device)
    target_variable = target_variable.to(device)
    mask = mask.to(device)
    # Lengths for rnn packing should always be on the cpu
    lengths = lengths.to("cpu")

    # Initialize variables
    loss = 0
    print_losses = []
    n_totals = 0

    # Forward pass through encoder
    encoder_outputs, encoder_hidden = encoder(input_variable, lengths)

    # Create initial decoder input (start with SOS tokens for each sentence)
    decoder_input = torch.LongTensor([[SOS_token for _ in range(batch_size)]])
    decoder_input = decoder_input.to(device)

    # Set initial decoder hidden state to the encoder's final hidden state
    decoder_hidden = encoder_hidden[:decoder.n_layers]

    # Determine if we are using teacher forcing this iteration
    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False

    # Forward batch of sequences through decoder one time step at a time
    if use_teacher_forcing:
        for t in range(max_target_len):
            decoder_output, decoder_hidden = decoder(
                decoder_input, decoder_hidden, encoder_outputs
            )
            # Teacher forcing: next input is current target
            decoder_input = target_variable[t].view(1, -1)
            # Calculate and accumulate loss
            mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
            loss += mask_loss
            print_losses.append(mask_loss.item() * nTotal)
            n_totals += nTotal
    else:
        for t in range(max_target_len):
            decoder_output, decoder_hidden = decoder(
                decoder_input, decoder_hidden, encoder_outputs
            )
            # No teacher forcing: next input is decoder's own current output
            _, topi = decoder_output.topk(1)
            decoder_input = torch.LongTensor([[topi[i][0] for i in range(batch_size)]])
            decoder_input = decoder_input.to(device)
            # Calculate and accumulate loss
            mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
            loss += mask_loss
            print_losses.append(mask_loss.item() * nTotal)
            n_totals += nTotal

    # Perform backpropatation
    loss.backward()

    # Clip gradients: gradients are modified in place
    _ = nn.utils.clip_grad_norm_(encoder.parameters(), clip)
    _ = nn.utils.clip_grad_norm_(decoder.parameters(), clip)

    # Adjust model weights
    encoder_optimizer.step()
    decoder_optimizer.step()

    return sum(print_losses) / n_totals

Training iterations + save our model to run inferences or continue training !

In [15]:
def trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer, embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size, print_every, save_every, clip, corpus_name, loadFilename):

    # Load batches for each iteration
    training_batches = [batch2TrainData(voc, [random.choice(pairs) for _ in range(batch_size)])
                      for _ in range(n_iteration)]

    # Initializations
    print('Initializing ...')
    start_iteration = 1
    print_loss = 0
    if loadFilename:
        start_iteration = checkpoint['iteration'] + 1

    # Training loop
    print("Training...")
    for iteration in range(start_iteration, n_iteration + 1):
        training_batch = training_batches[iteration - 1]
        # Extract fields from batch
        input_variable, lengths, target_variable, mask, max_target_len = training_batch

        # Run a training iteration with batch
        loss = train(input_variable, lengths, target_variable, mask, max_target_len, encoder,
                     decoder, embedding, encoder_optimizer, decoder_optimizer, batch_size, clip)
        print_loss += loss

        # Print progress
        if iteration % print_every == 0:
            print_loss_avg = print_loss / print_every
            print("Iteration: {}; Percent complete: {:.1f}%; Average loss: {:.4f}".format(iteration, iteration / n_iteration * 100, print_loss_avg))
            print_loss = 0

        # Save checkpoint
        if (iteration % save_every == 0):
            directory = os.path.join(save_dir, model_name, corpus_name, '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, hidden_size))
            if not os.path.exists(directory):
                os.makedirs(directory)
            torch.save({
                'iteration': iteration,
                'en': encoder.state_dict(),
                'de': decoder.state_dict(),
                'en_opt': encoder_optimizer.state_dict(),
                'de_opt': decoder_optimizer.state_dict(),
                'loss': loss,
                'voc_dict': voc.__dict__,
                'embedding': embedding.state_dict()
            }, os.path.join(directory, '{}_{}.tar'.format(iteration, 'checkpoint')))

## Gready decoding for generating sentences

Decoding method when training is not using teacher forcing. At each time step we choose the word from decoder_output with the highest softmax value.

In [16]:
class GreedySearchDecoder(nn.Module):
    def __init__(self, encoder, decoder):
        super(GreedySearchDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, input_seq, input_length, max_length):
        # Forward input through encoder model
        encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)
        # Prepare encoder's final hidden layer to be first hidden input to the decoder
        decoder_hidden = encoder_hidden[:decoder.n_layers]
        # Initialize decoder input with SOS_token
        decoder_input = torch.ones(1, 1, device=device, dtype=torch.long) * SOS_token
        # Initialize tensors to append decoded words to
        all_tokens = torch.zeros([0], device=device, dtype=torch.long)
        all_scores = torch.zeros([0], device=device)
        # Iteratively decode one word token at a time
        for _ in range(max_length):
            # Forward pass through decoder
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
            # Obtain most likely word token and its softmax score
            decoder_scores, decoder_input = torch.max(decoder_output, dim=1)
            # Record token and score
            all_tokens = torch.cat((all_tokens, decoder_input), dim=0)
            all_scores = torch.cat((all_scores, decoder_scores), dim=0)
            # Prepare current token to be next decoder input (add a dimension)
            decoder_input = torch.unsqueeze(decoder_input, 0)
        # Return collections of word tokens and scores
        return all_tokens, all_scores

## Evaluation of model

In [17]:
def evaluate(encoder, decoder, searcher, voc, sentence, max_length=MAX_LENGTH):
    ### Format input sentence as a batch
    # words -> indexes
    indexes_batch = [indexesFromSentence(voc, sentence)]
    # Create lengths tensor
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    # Transpose dimensions of batch to match models' expectations
    input_batch = torch.LongTensor(indexes_batch).transpose(0, 1)
    # Use appropriate device
    input_batch = input_batch.to(device)
    lengths = lengths.to(device)
    # Decode sentence with searcher
    tokens, scores = searcher(input_batch, lengths, max_length)
    # indexes -> words
    decoded_words = [voc.index2word[token.item()] for token in tokens]
    return decoded_words


def evaluateInput(encoder, decoder, searcher, voc):
    input_sentence = ''
    while(1):
        try:
            # Get input sentence
            input_sentence = input('> ')
            # Check if it is quit case
            if input_sentence == 'q' or input_sentence == 'quit': break
            # Normalize sentence
            input_sentence = normalizeString(input_sentence)
            # Evaluate sentence
            output_words = evaluate(encoder, decoder, searcher, voc, input_sentence)
            # Format and print response sentence
            output_words[:] = [x for x in output_words if not (x == 'EOS' or x == 'PAD')]
            print('Bot:', ' '.join(output_words))

        except KeyError:
            print("Error: Encountered unknown word.")

## Run model

possible to laod from checkpoint

In [18]:
# Configure models
model_name = 'cb_model'
attn_model = 'dot'
#attn_model = 'general'
#attn_model = 'concat'
hidden_size = 500
encoder_n_layers = 2
decoder_n_layers = 2
dropout = 0.1
batch_size = 64

# Set checkpoint to load from; set to None if starting from scratch
loadFilename = None
checkpoint_iter = 5000 # 4000
#loadFilename = os.path.join(save_dir, model_name, corpus_name,
#                            '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, hidden_size),
#                            '{}_checkpoint.tar'.format(checkpoint_iter))
#print(loadFilename)

# Load model if a loadFilename is provided
if loadFilename:
    # If loading on same machine the model was trained on
    checkpoint = torch.load(loadFilename)
    # If loading a model trained on GPU to CPU
    #checkpoint = torch.load(loadFilename, map_location=torch.device('cpu'))
    encoder_sd = checkpoint['en']
    decoder_sd = checkpoint['de']
    encoder_optimizer_sd = checkpoint['en_opt']
    decoder_optimizer_sd = checkpoint['de_opt']
    embedding_sd = checkpoint['embedding']
    voc.__dict__ = checkpoint['voc_dict']


print('Building encoder and decoder ...')
# Initialize word embeddings
embedding = nn.Embedding(voc.num_words, hidden_size)
if loadFilename:
    embedding.load_state_dict(embedding_sd)
# Initialize encoder & decoder models
encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.num_words, decoder_n_layers, dropout)
if loadFilename:
    encoder.load_state_dict(encoder_sd)
    decoder.load_state_dict(decoder_sd)
# Use appropriate device
encoder = encoder.to(device)
decoder = decoder.to(device)
print('Models built and ready to go!')

Building encoder and decoder ...
Models built and ready to go!


## Train !

In [19]:
# Configure training/optimization
clip = 50.0
teacher_forcing_ratio = 1.0
learning_rate = 0.0001
decoder_learning_ratio = 5.0
n_iteration = 5000 #4000
print_every = 1
save_every = 500

# Ensure dropout layers are in train mode
encoder.train()
decoder.train()

# Initialize optimizers
print('Building optimizers ...')
encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate * decoder_learning_ratio)
if loadFilename:
    encoder_optimizer.load_state_dict(encoder_optimizer_sd)
    decoder_optimizer.load_state_dict(decoder_optimizer_sd)

# If you have cuda, configure cuda to call
for state in encoder_optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.cuda()

for state in decoder_optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.cuda()

# Run training iterations
print("Starting Training!")
trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer,
           embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size,
           print_every, save_every, clip, corpus_name, loadFilename)

Building optimizers ...
Starting Training!
Initializing ...
Training...
Iteration: 1; Percent complete: 0.0%; Average loss: 8.7296
Iteration: 2; Percent complete: 0.0%; Average loss: 8.6548
Iteration: 3; Percent complete: 0.1%; Average loss: 8.5581
Iteration: 4; Percent complete: 0.1%; Average loss: 8.3776
Iteration: 5; Percent complete: 0.1%; Average loss: 8.1128
Iteration: 6; Percent complete: 0.1%; Average loss: 7.7241
Iteration: 7; Percent complete: 0.1%; Average loss: 7.3037
Iteration: 8; Percent complete: 0.2%; Average loss: 7.2076
Iteration: 9; Percent complete: 0.2%; Average loss: 7.2076
Iteration: 10; Percent complete: 0.2%; Average loss: 6.9492
Iteration: 11; Percent complete: 0.2%; Average loss: 6.6086
Iteration: 12; Percent complete: 0.2%; Average loss: 6.2786
Iteration: 13; Percent complete: 0.3%; Average loss: 5.9998
Iteration: 14; Percent complete: 0.3%; Average loss: 5.8282
Iteration: 15; Percent complete: 0.3%; Average loss: 5.9165
Iteration: 16; Percent complete: 0.3%

Iteration: 138; Percent complete: 2.8%; Average loss: 4.4163
Iteration: 139; Percent complete: 2.8%; Average loss: 4.5547
Iteration: 140; Percent complete: 2.8%; Average loss: 4.3723
Iteration: 141; Percent complete: 2.8%; Average loss: 4.5710
Iteration: 142; Percent complete: 2.8%; Average loss: 4.4606
Iteration: 143; Percent complete: 2.9%; Average loss: 4.5065
Iteration: 144; Percent complete: 2.9%; Average loss: 4.7042
Iteration: 145; Percent complete: 2.9%; Average loss: 4.3401
Iteration: 146; Percent complete: 2.9%; Average loss: 4.4223
Iteration: 147; Percent complete: 2.9%; Average loss: 4.4856
Iteration: 148; Percent complete: 3.0%; Average loss: 4.4403
Iteration: 149; Percent complete: 3.0%; Average loss: 4.6620
Iteration: 150; Percent complete: 3.0%; Average loss: 4.5207
Iteration: 151; Percent complete: 3.0%; Average loss: 4.4472
Iteration: 152; Percent complete: 3.0%; Average loss: 4.5033
Iteration: 153; Percent complete: 3.1%; Average loss: 4.2981
Iteration: 154; Percent 

Iteration: 273; Percent complete: 5.5%; Average loss: 3.9868
Iteration: 274; Percent complete: 5.5%; Average loss: 3.8176
Iteration: 275; Percent complete: 5.5%; Average loss: 3.8310
Iteration: 276; Percent complete: 5.5%; Average loss: 4.0078
Iteration: 277; Percent complete: 5.5%; Average loss: 4.0143
Iteration: 278; Percent complete: 5.6%; Average loss: 4.0499
Iteration: 279; Percent complete: 5.6%; Average loss: 4.3931
Iteration: 280; Percent complete: 5.6%; Average loss: 3.8173
Iteration: 281; Percent complete: 5.6%; Average loss: 4.0902
Iteration: 282; Percent complete: 5.6%; Average loss: 4.1898
Iteration: 283; Percent complete: 5.7%; Average loss: 3.9364
Iteration: 284; Percent complete: 5.7%; Average loss: 3.8634
Iteration: 285; Percent complete: 5.7%; Average loss: 4.1002
Iteration: 286; Percent complete: 5.7%; Average loss: 4.1174
Iteration: 287; Percent complete: 5.7%; Average loss: 4.1497
Iteration: 288; Percent complete: 5.8%; Average loss: 3.6731
Iteration: 289; Percent 

Iteration: 408; Percent complete: 8.2%; Average loss: 3.6319
Iteration: 409; Percent complete: 8.2%; Average loss: 3.9601
Iteration: 410; Percent complete: 8.2%; Average loss: 3.7076
Iteration: 411; Percent complete: 8.2%; Average loss: 3.8624
Iteration: 412; Percent complete: 8.2%; Average loss: 3.7134
Iteration: 413; Percent complete: 8.3%; Average loss: 3.7061
Iteration: 414; Percent complete: 8.3%; Average loss: 3.8810
Iteration: 415; Percent complete: 8.3%; Average loss: 3.8482
Iteration: 416; Percent complete: 8.3%; Average loss: 3.8468
Iteration: 417; Percent complete: 8.3%; Average loss: 3.5818
Iteration: 418; Percent complete: 8.4%; Average loss: 4.0741
Iteration: 419; Percent complete: 8.4%; Average loss: 3.4719
Iteration: 420; Percent complete: 8.4%; Average loss: 3.6679
Iteration: 421; Percent complete: 8.4%; Average loss: 3.8565
Iteration: 422; Percent complete: 8.4%; Average loss: 3.9066
Iteration: 423; Percent complete: 8.5%; Average loss: 3.7390
Iteration: 424; Percent 

Iteration: 543; Percent complete: 10.9%; Average loss: 3.6605
Iteration: 544; Percent complete: 10.9%; Average loss: 3.8485
Iteration: 545; Percent complete: 10.9%; Average loss: 3.5943
Iteration: 546; Percent complete: 10.9%; Average loss: 3.5785
Iteration: 547; Percent complete: 10.9%; Average loss: 3.9246
Iteration: 548; Percent complete: 11.0%; Average loss: 3.7404
Iteration: 549; Percent complete: 11.0%; Average loss: 3.5495
Iteration: 550; Percent complete: 11.0%; Average loss: 3.9756
Iteration: 551; Percent complete: 11.0%; Average loss: 3.6940
Iteration: 552; Percent complete: 11.0%; Average loss: 3.5288
Iteration: 553; Percent complete: 11.1%; Average loss: 3.5778
Iteration: 554; Percent complete: 11.1%; Average loss: 3.6650
Iteration: 555; Percent complete: 11.1%; Average loss: 3.7118
Iteration: 556; Percent complete: 11.1%; Average loss: 3.7294
Iteration: 557; Percent complete: 11.1%; Average loss: 3.6894
Iteration: 558; Percent complete: 11.2%; Average loss: 3.3654
Iteratio

Iteration: 678; Percent complete: 13.6%; Average loss: 3.6615
Iteration: 679; Percent complete: 13.6%; Average loss: 3.5788
Iteration: 680; Percent complete: 13.6%; Average loss: 3.8643
Iteration: 681; Percent complete: 13.6%; Average loss: 3.5501
Iteration: 682; Percent complete: 13.6%; Average loss: 3.6330
Iteration: 683; Percent complete: 13.7%; Average loss: 3.4624
Iteration: 684; Percent complete: 13.7%; Average loss: 3.5530
Iteration: 685; Percent complete: 13.7%; Average loss: 4.0160
Iteration: 686; Percent complete: 13.7%; Average loss: 3.5731
Iteration: 687; Percent complete: 13.7%; Average loss: 3.6552
Iteration: 688; Percent complete: 13.8%; Average loss: 3.7312
Iteration: 689; Percent complete: 13.8%; Average loss: 3.6400
Iteration: 690; Percent complete: 13.8%; Average loss: 3.6660
Iteration: 691; Percent complete: 13.8%; Average loss: 3.5243
Iteration: 692; Percent complete: 13.8%; Average loss: 3.5611
Iteration: 693; Percent complete: 13.9%; Average loss: 3.6742
Iteratio

Iteration: 812; Percent complete: 16.2%; Average loss: 3.1794
Iteration: 813; Percent complete: 16.3%; Average loss: 3.6811
Iteration: 814; Percent complete: 16.3%; Average loss: 3.4303
Iteration: 815; Percent complete: 16.3%; Average loss: 3.4517
Iteration: 816; Percent complete: 16.3%; Average loss: 3.4128
Iteration: 817; Percent complete: 16.3%; Average loss: 3.4981
Iteration: 818; Percent complete: 16.4%; Average loss: 3.6443
Iteration: 819; Percent complete: 16.4%; Average loss: 3.5296
Iteration: 820; Percent complete: 16.4%; Average loss: 3.3839
Iteration: 821; Percent complete: 16.4%; Average loss: 3.4011
Iteration: 822; Percent complete: 16.4%; Average loss: 3.3374
Iteration: 823; Percent complete: 16.5%; Average loss: 3.2692
Iteration: 824; Percent complete: 16.5%; Average loss: 3.1386
Iteration: 825; Percent complete: 16.5%; Average loss: 3.6107
Iteration: 826; Percent complete: 16.5%; Average loss: 3.4273
Iteration: 827; Percent complete: 16.5%; Average loss: 3.4099
Iteratio

Iteration: 947; Percent complete: 18.9%; Average loss: 3.3634
Iteration: 948; Percent complete: 19.0%; Average loss: 3.3811
Iteration: 949; Percent complete: 19.0%; Average loss: 3.5291
Iteration: 950; Percent complete: 19.0%; Average loss: 3.3660
Iteration: 951; Percent complete: 19.0%; Average loss: 3.3888
Iteration: 952; Percent complete: 19.0%; Average loss: 3.2561
Iteration: 953; Percent complete: 19.1%; Average loss: 3.5312
Iteration: 954; Percent complete: 19.1%; Average loss: 3.1459
Iteration: 955; Percent complete: 19.1%; Average loss: 3.4241
Iteration: 956; Percent complete: 19.1%; Average loss: 3.5017
Iteration: 957; Percent complete: 19.1%; Average loss: 3.0830
Iteration: 958; Percent complete: 19.2%; Average loss: 3.4289
Iteration: 959; Percent complete: 19.2%; Average loss: 3.2650
Iteration: 960; Percent complete: 19.2%; Average loss: 3.4806
Iteration: 961; Percent complete: 19.2%; Average loss: 3.3796
Iteration: 962; Percent complete: 19.2%; Average loss: 3.5304
Iteratio

Iteration: 1079; Percent complete: 21.6%; Average loss: 3.1691
Iteration: 1080; Percent complete: 21.6%; Average loss: 3.2204
Iteration: 1081; Percent complete: 21.6%; Average loss: 3.3960
Iteration: 1082; Percent complete: 21.6%; Average loss: 3.1879
Iteration: 1083; Percent complete: 21.7%; Average loss: 3.1803
Iteration: 1084; Percent complete: 21.7%; Average loss: 3.3239
Iteration: 1085; Percent complete: 21.7%; Average loss: 3.3792
Iteration: 1086; Percent complete: 21.7%; Average loss: 3.4982
Iteration: 1087; Percent complete: 21.7%; Average loss: 3.3088
Iteration: 1088; Percent complete: 21.8%; Average loss: 3.0165
Iteration: 1089; Percent complete: 21.8%; Average loss: 3.3623
Iteration: 1090; Percent complete: 21.8%; Average loss: 3.3152
Iteration: 1091; Percent complete: 21.8%; Average loss: 3.4649
Iteration: 1092; Percent complete: 21.8%; Average loss: 3.2285
Iteration: 1093; Percent complete: 21.9%; Average loss: 3.3655
Iteration: 1094; Percent complete: 21.9%; Average loss:

Iteration: 1211; Percent complete: 24.2%; Average loss: 3.2746
Iteration: 1212; Percent complete: 24.2%; Average loss: 3.1058
Iteration: 1213; Percent complete: 24.3%; Average loss: 3.1832
Iteration: 1214; Percent complete: 24.3%; Average loss: 3.2417
Iteration: 1215; Percent complete: 24.3%; Average loss: 3.0905
Iteration: 1216; Percent complete: 24.3%; Average loss: 3.2891
Iteration: 1217; Percent complete: 24.3%; Average loss: 3.1159
Iteration: 1218; Percent complete: 24.4%; Average loss: 2.9869
Iteration: 1219; Percent complete: 24.4%; Average loss: 2.9963
Iteration: 1220; Percent complete: 24.4%; Average loss: 2.9976
Iteration: 1221; Percent complete: 24.4%; Average loss: 3.2476
Iteration: 1222; Percent complete: 24.4%; Average loss: 3.1066
Iteration: 1223; Percent complete: 24.5%; Average loss: 3.3584
Iteration: 1224; Percent complete: 24.5%; Average loss: 3.2625
Iteration: 1225; Percent complete: 24.5%; Average loss: 3.0067
Iteration: 1226; Percent complete: 24.5%; Average loss:

Iteration: 1343; Percent complete: 26.9%; Average loss: 3.2109
Iteration: 1344; Percent complete: 26.9%; Average loss: 3.2907
Iteration: 1345; Percent complete: 26.9%; Average loss: 3.0875
Iteration: 1346; Percent complete: 26.9%; Average loss: 3.0441
Iteration: 1347; Percent complete: 26.9%; Average loss: 3.0189
Iteration: 1348; Percent complete: 27.0%; Average loss: 3.3660
Iteration: 1349; Percent complete: 27.0%; Average loss: 2.8149
Iteration: 1350; Percent complete: 27.0%; Average loss: 3.2276
Iteration: 1351; Percent complete: 27.0%; Average loss: 3.1408
Iteration: 1352; Percent complete: 27.0%; Average loss: 2.8536
Iteration: 1353; Percent complete: 27.1%; Average loss: 3.1586
Iteration: 1354; Percent complete: 27.1%; Average loss: 3.0601
Iteration: 1355; Percent complete: 27.1%; Average loss: 2.8859
Iteration: 1356; Percent complete: 27.1%; Average loss: 2.9785
Iteration: 1357; Percent complete: 27.1%; Average loss: 2.9438
Iteration: 1358; Percent complete: 27.2%; Average loss:

Iteration: 1475; Percent complete: 29.5%; Average loss: 2.8063
Iteration: 1476; Percent complete: 29.5%; Average loss: 3.0611
Iteration: 1477; Percent complete: 29.5%; Average loss: 2.9834
Iteration: 1478; Percent complete: 29.6%; Average loss: 2.9333
Iteration: 1479; Percent complete: 29.6%; Average loss: 3.0133
Iteration: 1480; Percent complete: 29.6%; Average loss: 3.1296
Iteration: 1481; Percent complete: 29.6%; Average loss: 2.9884
Iteration: 1482; Percent complete: 29.6%; Average loss: 3.0270
Iteration: 1483; Percent complete: 29.7%; Average loss: 3.0664
Iteration: 1484; Percent complete: 29.7%; Average loss: 2.8157
Iteration: 1485; Percent complete: 29.7%; Average loss: 2.9103
Iteration: 1486; Percent complete: 29.7%; Average loss: 3.2122
Iteration: 1487; Percent complete: 29.7%; Average loss: 2.8738
Iteration: 1488; Percent complete: 29.8%; Average loss: 3.0347
Iteration: 1489; Percent complete: 29.8%; Average loss: 3.1452
Iteration: 1490; Percent complete: 29.8%; Average loss:

Iteration: 1606; Percent complete: 32.1%; Average loss: 2.8666
Iteration: 1607; Percent complete: 32.1%; Average loss: 2.8748
Iteration: 1608; Percent complete: 32.2%; Average loss: 2.9261
Iteration: 1609; Percent complete: 32.2%; Average loss: 2.9532
Iteration: 1610; Percent complete: 32.2%; Average loss: 2.9109
Iteration: 1611; Percent complete: 32.2%; Average loss: 3.1058
Iteration: 1612; Percent complete: 32.2%; Average loss: 2.9569
Iteration: 1613; Percent complete: 32.3%; Average loss: 3.0975
Iteration: 1614; Percent complete: 32.3%; Average loss: 3.1240
Iteration: 1615; Percent complete: 32.3%; Average loss: 2.9168
Iteration: 1616; Percent complete: 32.3%; Average loss: 2.7841
Iteration: 1617; Percent complete: 32.3%; Average loss: 3.1047
Iteration: 1618; Percent complete: 32.4%; Average loss: 3.0608
Iteration: 1619; Percent complete: 32.4%; Average loss: 2.7758
Iteration: 1620; Percent complete: 32.4%; Average loss: 2.8262
Iteration: 1621; Percent complete: 32.4%; Average loss:

Iteration: 1738; Percent complete: 34.8%; Average loss: 2.9875
Iteration: 1739; Percent complete: 34.8%; Average loss: 2.9907
Iteration: 1740; Percent complete: 34.8%; Average loss: 2.5751
Iteration: 1741; Percent complete: 34.8%; Average loss: 2.7862
Iteration: 1742; Percent complete: 34.8%; Average loss: 2.8727
Iteration: 1743; Percent complete: 34.9%; Average loss: 2.8742
Iteration: 1744; Percent complete: 34.9%; Average loss: 2.9844
Iteration: 1745; Percent complete: 34.9%; Average loss: 2.9667
Iteration: 1746; Percent complete: 34.9%; Average loss: 3.1219
Iteration: 1747; Percent complete: 34.9%; Average loss: 2.7343
Iteration: 1748; Percent complete: 35.0%; Average loss: 2.9682
Iteration: 1749; Percent complete: 35.0%; Average loss: 3.1064
Iteration: 1750; Percent complete: 35.0%; Average loss: 2.9227
Iteration: 1751; Percent complete: 35.0%; Average loss: 3.2059
Iteration: 1752; Percent complete: 35.0%; Average loss: 2.9021
Iteration: 1753; Percent complete: 35.1%; Average loss:

Iteration: 1870; Percent complete: 37.4%; Average loss: 2.8795
Iteration: 1871; Percent complete: 37.4%; Average loss: 2.8040
Iteration: 1872; Percent complete: 37.4%; Average loss: 2.9257
Iteration: 1873; Percent complete: 37.5%; Average loss: 2.6280
Iteration: 1874; Percent complete: 37.5%; Average loss: 2.3977
Iteration: 1875; Percent complete: 37.5%; Average loss: 2.6688
Iteration: 1876; Percent complete: 37.5%; Average loss: 2.8177
Iteration: 1877; Percent complete: 37.5%; Average loss: 2.7954
Iteration: 1878; Percent complete: 37.6%; Average loss: 2.6406
Iteration: 1879; Percent complete: 37.6%; Average loss: 2.9158
Iteration: 1880; Percent complete: 37.6%; Average loss: 2.8788
Iteration: 1881; Percent complete: 37.6%; Average loss: 2.8108
Iteration: 1882; Percent complete: 37.6%; Average loss: 2.8263
Iteration: 1883; Percent complete: 37.7%; Average loss: 2.7758
Iteration: 1884; Percent complete: 37.7%; Average loss: 2.7790
Iteration: 1885; Percent complete: 37.7%; Average loss:

Iteration: 2001; Percent complete: 40.0%; Average loss: 2.8909
Iteration: 2002; Percent complete: 40.0%; Average loss: 2.5568
Iteration: 2003; Percent complete: 40.1%; Average loss: 2.5181
Iteration: 2004; Percent complete: 40.1%; Average loss: 2.8491
Iteration: 2005; Percent complete: 40.1%; Average loss: 2.7213
Iteration: 2006; Percent complete: 40.1%; Average loss: 3.0049
Iteration: 2007; Percent complete: 40.1%; Average loss: 2.8065
Iteration: 2008; Percent complete: 40.2%; Average loss: 2.6018
Iteration: 2009; Percent complete: 40.2%; Average loss: 2.7496
Iteration: 2010; Percent complete: 40.2%; Average loss: 2.8840
Iteration: 2011; Percent complete: 40.2%; Average loss: 2.7125
Iteration: 2012; Percent complete: 40.2%; Average loss: 2.6427
Iteration: 2013; Percent complete: 40.3%; Average loss: 2.9260
Iteration: 2014; Percent complete: 40.3%; Average loss: 2.5678
Iteration: 2015; Percent complete: 40.3%; Average loss: 2.9163
Iteration: 2016; Percent complete: 40.3%; Average loss:

Iteration: 2133; Percent complete: 42.7%; Average loss: 2.6862
Iteration: 2134; Percent complete: 42.7%; Average loss: 2.6691
Iteration: 2135; Percent complete: 42.7%; Average loss: 2.8210
Iteration: 2136; Percent complete: 42.7%; Average loss: 2.5166
Iteration: 2137; Percent complete: 42.7%; Average loss: 2.7406
Iteration: 2138; Percent complete: 42.8%; Average loss: 2.7290
Iteration: 2139; Percent complete: 42.8%; Average loss: 2.8623
Iteration: 2140; Percent complete: 42.8%; Average loss: 2.5527
Iteration: 2141; Percent complete: 42.8%; Average loss: 2.7951
Iteration: 2142; Percent complete: 42.8%; Average loss: 2.8187
Iteration: 2143; Percent complete: 42.9%; Average loss: 2.5389
Iteration: 2144; Percent complete: 42.9%; Average loss: 2.8424
Iteration: 2145; Percent complete: 42.9%; Average loss: 2.8877
Iteration: 2146; Percent complete: 42.9%; Average loss: 2.8038
Iteration: 2147; Percent complete: 42.9%; Average loss: 2.7270
Iteration: 2148; Percent complete: 43.0%; Average loss:

Iteration: 2265; Percent complete: 45.3%; Average loss: 2.7870
Iteration: 2266; Percent complete: 45.3%; Average loss: 2.5187
Iteration: 2267; Percent complete: 45.3%; Average loss: 2.5681
Iteration: 2268; Percent complete: 45.4%; Average loss: 2.6474
Iteration: 2269; Percent complete: 45.4%; Average loss: 2.7794
Iteration: 2270; Percent complete: 45.4%; Average loss: 2.5151
Iteration: 2271; Percent complete: 45.4%; Average loss: 2.5342
Iteration: 2272; Percent complete: 45.4%; Average loss: 2.8701
Iteration: 2273; Percent complete: 45.5%; Average loss: 2.4696
Iteration: 2274; Percent complete: 45.5%; Average loss: 2.6071
Iteration: 2275; Percent complete: 45.5%; Average loss: 2.7079
Iteration: 2276; Percent complete: 45.5%; Average loss: 2.7620
Iteration: 2277; Percent complete: 45.5%; Average loss: 2.7254
Iteration: 2278; Percent complete: 45.6%; Average loss: 2.5275
Iteration: 2279; Percent complete: 45.6%; Average loss: 2.4384
Iteration: 2280; Percent complete: 45.6%; Average loss:

Iteration: 2396; Percent complete: 47.9%; Average loss: 2.4323
Iteration: 2397; Percent complete: 47.9%; Average loss: 2.2800
Iteration: 2398; Percent complete: 48.0%; Average loss: 2.8490
Iteration: 2399; Percent complete: 48.0%; Average loss: 2.8936
Iteration: 2400; Percent complete: 48.0%; Average loss: 2.6851
Iteration: 2401; Percent complete: 48.0%; Average loss: 2.5474
Iteration: 2402; Percent complete: 48.0%; Average loss: 2.5077
Iteration: 2403; Percent complete: 48.1%; Average loss: 2.5348
Iteration: 2404; Percent complete: 48.1%; Average loss: 2.5674
Iteration: 2405; Percent complete: 48.1%; Average loss: 2.5337
Iteration: 2406; Percent complete: 48.1%; Average loss: 2.6265
Iteration: 2407; Percent complete: 48.1%; Average loss: 2.5710
Iteration: 2408; Percent complete: 48.2%; Average loss: 2.5463
Iteration: 2409; Percent complete: 48.2%; Average loss: 2.4607
Iteration: 2410; Percent complete: 48.2%; Average loss: 2.6795
Iteration: 2411; Percent complete: 48.2%; Average loss:

Iteration: 2528; Percent complete: 50.6%; Average loss: 2.4412
Iteration: 2529; Percent complete: 50.6%; Average loss: 2.5953
Iteration: 2530; Percent complete: 50.6%; Average loss: 2.5805
Iteration: 2531; Percent complete: 50.6%; Average loss: 2.4910
Iteration: 2532; Percent complete: 50.6%; Average loss: 2.5969
Iteration: 2533; Percent complete: 50.7%; Average loss: 2.6008
Iteration: 2534; Percent complete: 50.7%; Average loss: 2.4965
Iteration: 2535; Percent complete: 50.7%; Average loss: 2.4177
Iteration: 2536; Percent complete: 50.7%; Average loss: 2.6432
Iteration: 2537; Percent complete: 50.7%; Average loss: 2.5237
Iteration: 2538; Percent complete: 50.8%; Average loss: 2.5787
Iteration: 2539; Percent complete: 50.8%; Average loss: 2.5975
Iteration: 2540; Percent complete: 50.8%; Average loss: 2.6184
Iteration: 2541; Percent complete: 50.8%; Average loss: 2.5680
Iteration: 2542; Percent complete: 50.8%; Average loss: 2.5941
Iteration: 2543; Percent complete: 50.9%; Average loss:

Iteration: 2660; Percent complete: 53.2%; Average loss: 2.4182
Iteration: 2661; Percent complete: 53.2%; Average loss: 2.5190
Iteration: 2662; Percent complete: 53.2%; Average loss: 2.4837
Iteration: 2663; Percent complete: 53.3%; Average loss: 2.4692
Iteration: 2664; Percent complete: 53.3%; Average loss: 2.3945
Iteration: 2665; Percent complete: 53.3%; Average loss: 2.2900
Iteration: 2666; Percent complete: 53.3%; Average loss: 2.5811
Iteration: 2667; Percent complete: 53.3%; Average loss: 2.3335
Iteration: 2668; Percent complete: 53.4%; Average loss: 2.2947
Iteration: 2669; Percent complete: 53.4%; Average loss: 2.6512
Iteration: 2670; Percent complete: 53.4%; Average loss: 2.5148
Iteration: 2671; Percent complete: 53.4%; Average loss: 2.6442
Iteration: 2672; Percent complete: 53.4%; Average loss: 2.2809
Iteration: 2673; Percent complete: 53.5%; Average loss: 2.7285
Iteration: 2674; Percent complete: 53.5%; Average loss: 2.2081
Iteration: 2675; Percent complete: 53.5%; Average loss:

Iteration: 2791; Percent complete: 55.8%; Average loss: 2.1334
Iteration: 2792; Percent complete: 55.8%; Average loss: 2.5605
Iteration: 2793; Percent complete: 55.9%; Average loss: 2.5533
Iteration: 2794; Percent complete: 55.9%; Average loss: 2.5510
Iteration: 2795; Percent complete: 55.9%; Average loss: 2.2525
Iteration: 2796; Percent complete: 55.9%; Average loss: 2.6117
Iteration: 2797; Percent complete: 55.9%; Average loss: 2.5800
Iteration: 2798; Percent complete: 56.0%; Average loss: 2.6162
Iteration: 2799; Percent complete: 56.0%; Average loss: 2.3936
Iteration: 2800; Percent complete: 56.0%; Average loss: 2.5183
Iteration: 2801; Percent complete: 56.0%; Average loss: 2.4741
Iteration: 2802; Percent complete: 56.0%; Average loss: 2.3565
Iteration: 2803; Percent complete: 56.1%; Average loss: 2.6149
Iteration: 2804; Percent complete: 56.1%; Average loss: 2.4856
Iteration: 2805; Percent complete: 56.1%; Average loss: 2.3432
Iteration: 2806; Percent complete: 56.1%; Average loss:

Iteration: 2922; Percent complete: 58.4%; Average loss: 2.3335
Iteration: 2923; Percent complete: 58.5%; Average loss: 2.2283
Iteration: 2924; Percent complete: 58.5%; Average loss: 2.3242
Iteration: 2925; Percent complete: 58.5%; Average loss: 2.5393
Iteration: 2926; Percent complete: 58.5%; Average loss: 2.4601
Iteration: 2927; Percent complete: 58.5%; Average loss: 2.4292
Iteration: 2928; Percent complete: 58.6%; Average loss: 2.3645
Iteration: 2929; Percent complete: 58.6%; Average loss: 2.5255
Iteration: 2930; Percent complete: 58.6%; Average loss: 2.2337
Iteration: 2931; Percent complete: 58.6%; Average loss: 2.4390
Iteration: 2932; Percent complete: 58.6%; Average loss: 2.5360
Iteration: 2933; Percent complete: 58.7%; Average loss: 2.2807
Iteration: 2934; Percent complete: 58.7%; Average loss: 2.3599
Iteration: 2935; Percent complete: 58.7%; Average loss: 2.2772
Iteration: 2936; Percent complete: 58.7%; Average loss: 2.2669
Iteration: 2937; Percent complete: 58.7%; Average loss:

Iteration: 3055; Percent complete: 61.1%; Average loss: 2.1651
Iteration: 3056; Percent complete: 61.1%; Average loss: 2.1762
Iteration: 3057; Percent complete: 61.1%; Average loss: 2.4538
Iteration: 3058; Percent complete: 61.2%; Average loss: 2.2187
Iteration: 3059; Percent complete: 61.2%; Average loss: 2.3276
Iteration: 3060; Percent complete: 61.2%; Average loss: 2.4309
Iteration: 3061; Percent complete: 61.2%; Average loss: 2.2835
Iteration: 3062; Percent complete: 61.2%; Average loss: 2.5348
Iteration: 3063; Percent complete: 61.3%; Average loss: 2.2273
Iteration: 3064; Percent complete: 61.3%; Average loss: 2.3586
Iteration: 3065; Percent complete: 61.3%; Average loss: 2.4199
Iteration: 3066; Percent complete: 61.3%; Average loss: 2.3299
Iteration: 3067; Percent complete: 61.3%; Average loss: 2.2928
Iteration: 3068; Percent complete: 61.4%; Average loss: 2.1757
Iteration: 3069; Percent complete: 61.4%; Average loss: 2.2413
Iteration: 3070; Percent complete: 61.4%; Average loss:

Iteration: 3188; Percent complete: 63.8%; Average loss: 2.3086
Iteration: 3189; Percent complete: 63.8%; Average loss: 2.3552
Iteration: 3190; Percent complete: 63.8%; Average loss: 2.3162
Iteration: 3191; Percent complete: 63.8%; Average loss: 2.4229
Iteration: 3192; Percent complete: 63.8%; Average loss: 2.0934
Iteration: 3193; Percent complete: 63.9%; Average loss: 2.3661
Iteration: 3194; Percent complete: 63.9%; Average loss: 2.0747
Iteration: 3195; Percent complete: 63.9%; Average loss: 2.3256
Iteration: 3196; Percent complete: 63.9%; Average loss: 2.1824
Iteration: 3197; Percent complete: 63.9%; Average loss: 2.1948
Iteration: 3198; Percent complete: 64.0%; Average loss: 2.2926
Iteration: 3199; Percent complete: 64.0%; Average loss: 2.2652
Iteration: 3200; Percent complete: 64.0%; Average loss: 2.3002
Iteration: 3201; Percent complete: 64.0%; Average loss: 2.1418
Iteration: 3202; Percent complete: 64.0%; Average loss: 2.3825
Iteration: 3203; Percent complete: 64.1%; Average loss:

Iteration: 3320; Percent complete: 66.4%; Average loss: 2.0347
Iteration: 3321; Percent complete: 66.4%; Average loss: 2.2567
Iteration: 3322; Percent complete: 66.4%; Average loss: 2.2818
Iteration: 3323; Percent complete: 66.5%; Average loss: 2.1133
Iteration: 3324; Percent complete: 66.5%; Average loss: 2.1284
Iteration: 3325; Percent complete: 66.5%; Average loss: 2.2919
Iteration: 3326; Percent complete: 66.5%; Average loss: 2.1911
Iteration: 3327; Percent complete: 66.5%; Average loss: 2.2112
Iteration: 3328; Percent complete: 66.6%; Average loss: 1.9736
Iteration: 3329; Percent complete: 66.6%; Average loss: 2.1109
Iteration: 3330; Percent complete: 66.6%; Average loss: 2.3026
Iteration: 3331; Percent complete: 66.6%; Average loss: 2.0536
Iteration: 3332; Percent complete: 66.6%; Average loss: 2.1996
Iteration: 3333; Percent complete: 66.7%; Average loss: 2.2223
Iteration: 3334; Percent complete: 66.7%; Average loss: 2.3424
Iteration: 3335; Percent complete: 66.7%; Average loss:

Iteration: 3451; Percent complete: 69.0%; Average loss: 2.2029
Iteration: 3452; Percent complete: 69.0%; Average loss: 2.2608
Iteration: 3453; Percent complete: 69.1%; Average loss: 2.2705
Iteration: 3454; Percent complete: 69.1%; Average loss: 2.4295
Iteration: 3455; Percent complete: 69.1%; Average loss: 2.0226
Iteration: 3456; Percent complete: 69.1%; Average loss: 2.3668
Iteration: 3457; Percent complete: 69.1%; Average loss: 2.1290
Iteration: 3458; Percent complete: 69.2%; Average loss: 2.3855
Iteration: 3459; Percent complete: 69.2%; Average loss: 2.1124
Iteration: 3460; Percent complete: 69.2%; Average loss: 2.0619
Iteration: 3461; Percent complete: 69.2%; Average loss: 2.0618
Iteration: 3462; Percent complete: 69.2%; Average loss: 2.2258
Iteration: 3463; Percent complete: 69.3%; Average loss: 2.4813
Iteration: 3464; Percent complete: 69.3%; Average loss: 2.2825
Iteration: 3465; Percent complete: 69.3%; Average loss: 2.1716
Iteration: 3466; Percent complete: 69.3%; Average loss:

Iteration: 3582; Percent complete: 71.6%; Average loss: 1.9980
Iteration: 3583; Percent complete: 71.7%; Average loss: 2.3194
Iteration: 3584; Percent complete: 71.7%; Average loss: 2.0425
Iteration: 3585; Percent complete: 71.7%; Average loss: 2.0147
Iteration: 3586; Percent complete: 71.7%; Average loss: 2.1314
Iteration: 3587; Percent complete: 71.7%; Average loss: 2.0034
Iteration: 3588; Percent complete: 71.8%; Average loss: 1.9628
Iteration: 3589; Percent complete: 71.8%; Average loss: 2.1423
Iteration: 3590; Percent complete: 71.8%; Average loss: 2.0853
Iteration: 3591; Percent complete: 71.8%; Average loss: 2.1913
Iteration: 3592; Percent complete: 71.8%; Average loss: 2.1693
Iteration: 3593; Percent complete: 71.9%; Average loss: 2.0927
Iteration: 3594; Percent complete: 71.9%; Average loss: 2.0452
Iteration: 3595; Percent complete: 71.9%; Average loss: 2.0558
Iteration: 3596; Percent complete: 71.9%; Average loss: 2.1813
Iteration: 3597; Percent complete: 71.9%; Average loss:

Iteration: 3714; Percent complete: 74.3%; Average loss: 2.2363
Iteration: 3715; Percent complete: 74.3%; Average loss: 1.9983
Iteration: 3716; Percent complete: 74.3%; Average loss: 2.0901
Iteration: 3717; Percent complete: 74.3%; Average loss: 2.1998
Iteration: 3718; Percent complete: 74.4%; Average loss: 1.9991
Iteration: 3719; Percent complete: 74.4%; Average loss: 1.9701
Iteration: 3720; Percent complete: 74.4%; Average loss: 2.0396
Iteration: 3721; Percent complete: 74.4%; Average loss: 2.0934
Iteration: 3722; Percent complete: 74.4%; Average loss: 2.1451
Iteration: 3723; Percent complete: 74.5%; Average loss: 2.0436
Iteration: 3724; Percent complete: 74.5%; Average loss: 2.1734
Iteration: 3725; Percent complete: 74.5%; Average loss: 2.1094
Iteration: 3726; Percent complete: 74.5%; Average loss: 1.9715
Iteration: 3727; Percent complete: 74.5%; Average loss: 1.9490
Iteration: 3728; Percent complete: 74.6%; Average loss: 2.1743
Iteration: 3729; Percent complete: 74.6%; Average loss:

Iteration: 3846; Percent complete: 76.9%; Average loss: 1.9782
Iteration: 3847; Percent complete: 76.9%; Average loss: 2.0813
Iteration: 3848; Percent complete: 77.0%; Average loss: 1.9337
Iteration: 3849; Percent complete: 77.0%; Average loss: 1.7058
Iteration: 3850; Percent complete: 77.0%; Average loss: 2.0728
Iteration: 3851; Percent complete: 77.0%; Average loss: 2.1212
Iteration: 3852; Percent complete: 77.0%; Average loss: 2.2757
Iteration: 3853; Percent complete: 77.1%; Average loss: 2.0792
Iteration: 3854; Percent complete: 77.1%; Average loss: 1.7920
Iteration: 3855; Percent complete: 77.1%; Average loss: 2.0706
Iteration: 3856; Percent complete: 77.1%; Average loss: 2.0392
Iteration: 3857; Percent complete: 77.1%; Average loss: 1.9365
Iteration: 3858; Percent complete: 77.2%; Average loss: 2.2426
Iteration: 3859; Percent complete: 77.2%; Average loss: 1.9903
Iteration: 3860; Percent complete: 77.2%; Average loss: 1.8808
Iteration: 3861; Percent complete: 77.2%; Average loss:

Iteration: 3978; Percent complete: 79.6%; Average loss: 2.0341
Iteration: 3979; Percent complete: 79.6%; Average loss: 1.9079
Iteration: 3980; Percent complete: 79.6%; Average loss: 1.9940
Iteration: 3981; Percent complete: 79.6%; Average loss: 1.9539
Iteration: 3982; Percent complete: 79.6%; Average loss: 1.7986
Iteration: 3983; Percent complete: 79.7%; Average loss: 2.2641
Iteration: 3984; Percent complete: 79.7%; Average loss: 2.0622
Iteration: 3985; Percent complete: 79.7%; Average loss: 1.9224
Iteration: 3986; Percent complete: 79.7%; Average loss: 1.9184
Iteration: 3987; Percent complete: 79.7%; Average loss: 1.9799
Iteration: 3988; Percent complete: 79.8%; Average loss: 1.9197
Iteration: 3989; Percent complete: 79.8%; Average loss: 2.1691
Iteration: 3990; Percent complete: 79.8%; Average loss: 1.9165
Iteration: 3991; Percent complete: 79.8%; Average loss: 2.0218
Iteration: 3992; Percent complete: 79.8%; Average loss: 1.9009
Iteration: 3993; Percent complete: 79.9%; Average loss:

Iteration: 4109; Percent complete: 82.2%; Average loss: 1.9923
Iteration: 4110; Percent complete: 82.2%; Average loss: 1.8950
Iteration: 4111; Percent complete: 82.2%; Average loss: 2.1148
Iteration: 4112; Percent complete: 82.2%; Average loss: 1.7517
Iteration: 4113; Percent complete: 82.3%; Average loss: 1.8862
Iteration: 4114; Percent complete: 82.3%; Average loss: 2.0145
Iteration: 4115; Percent complete: 82.3%; Average loss: 2.1167
Iteration: 4116; Percent complete: 82.3%; Average loss: 1.8891
Iteration: 4117; Percent complete: 82.3%; Average loss: 1.8879
Iteration: 4118; Percent complete: 82.4%; Average loss: 2.0767
Iteration: 4119; Percent complete: 82.4%; Average loss: 1.9407
Iteration: 4120; Percent complete: 82.4%; Average loss: 2.0631
Iteration: 4121; Percent complete: 82.4%; Average loss: 1.9767
Iteration: 4122; Percent complete: 82.4%; Average loss: 1.8062
Iteration: 4123; Percent complete: 82.5%; Average loss: 1.8873
Iteration: 4124; Percent complete: 82.5%; Average loss:

Iteration: 4241; Percent complete: 84.8%; Average loss: 1.9945
Iteration: 4242; Percent complete: 84.8%; Average loss: 1.8802
Iteration: 4243; Percent complete: 84.9%; Average loss: 1.8799
Iteration: 4244; Percent complete: 84.9%; Average loss: 1.8833
Iteration: 4245; Percent complete: 84.9%; Average loss: 1.7102
Iteration: 4246; Percent complete: 84.9%; Average loss: 1.8145
Iteration: 4247; Percent complete: 84.9%; Average loss: 1.8199
Iteration: 4248; Percent complete: 85.0%; Average loss: 1.7235
Iteration: 4249; Percent complete: 85.0%; Average loss: 1.7706
Iteration: 4250; Percent complete: 85.0%; Average loss: 1.9157
Iteration: 4251; Percent complete: 85.0%; Average loss: 1.8954
Iteration: 4252; Percent complete: 85.0%; Average loss: 1.8975
Iteration: 4253; Percent complete: 85.1%; Average loss: 1.8293
Iteration: 4254; Percent complete: 85.1%; Average loss: 1.9179
Iteration: 4255; Percent complete: 85.1%; Average loss: 1.8773
Iteration: 4256; Percent complete: 85.1%; Average loss:

Iteration: 4373; Percent complete: 87.5%; Average loss: 1.9029
Iteration: 4374; Percent complete: 87.5%; Average loss: 2.0213
Iteration: 4375; Percent complete: 87.5%; Average loss: 1.6946
Iteration: 4376; Percent complete: 87.5%; Average loss: 1.7511
Iteration: 4377; Percent complete: 87.5%; Average loss: 1.9284
Iteration: 4378; Percent complete: 87.6%; Average loss: 1.8392
Iteration: 4379; Percent complete: 87.6%; Average loss: 1.8296
Iteration: 4380; Percent complete: 87.6%; Average loss: 1.7503
Iteration: 4381; Percent complete: 87.6%; Average loss: 1.9219
Iteration: 4382; Percent complete: 87.6%; Average loss: 1.7850
Iteration: 4383; Percent complete: 87.7%; Average loss: 1.8353
Iteration: 4384; Percent complete: 87.7%; Average loss: 1.8953
Iteration: 4385; Percent complete: 87.7%; Average loss: 1.7884
Iteration: 4386; Percent complete: 87.7%; Average loss: 1.8527
Iteration: 4387; Percent complete: 87.7%; Average loss: 1.8287
Iteration: 4388; Percent complete: 87.8%; Average loss:

Iteration: 4504; Percent complete: 90.1%; Average loss: 1.8569
Iteration: 4505; Percent complete: 90.1%; Average loss: 1.7202
Iteration: 4506; Percent complete: 90.1%; Average loss: 1.7464
Iteration: 4507; Percent complete: 90.1%; Average loss: 1.8211
Iteration: 4508; Percent complete: 90.2%; Average loss: 1.5835
Iteration: 4509; Percent complete: 90.2%; Average loss: 2.1247
Iteration: 4510; Percent complete: 90.2%; Average loss: 1.6659
Iteration: 4511; Percent complete: 90.2%; Average loss: 1.7213
Iteration: 4512; Percent complete: 90.2%; Average loss: 1.8329
Iteration: 4513; Percent complete: 90.3%; Average loss: 1.6229
Iteration: 4514; Percent complete: 90.3%; Average loss: 1.9934
Iteration: 4515; Percent complete: 90.3%; Average loss: 1.7610
Iteration: 4516; Percent complete: 90.3%; Average loss: 1.6804
Iteration: 4517; Percent complete: 90.3%; Average loss: 1.6126
Iteration: 4518; Percent complete: 90.4%; Average loss: 1.6720
Iteration: 4519; Percent complete: 90.4%; Average loss:

Iteration: 4636; Percent complete: 92.7%; Average loss: 1.5796
Iteration: 4637; Percent complete: 92.7%; Average loss: 1.7366
Iteration: 4638; Percent complete: 92.8%; Average loss: 1.9772
Iteration: 4639; Percent complete: 92.8%; Average loss: 1.5735
Iteration: 4640; Percent complete: 92.8%; Average loss: 1.8758
Iteration: 4641; Percent complete: 92.8%; Average loss: 1.7183
Iteration: 4642; Percent complete: 92.8%; Average loss: 1.8577
Iteration: 4643; Percent complete: 92.9%; Average loss: 1.5958
Iteration: 4644; Percent complete: 92.9%; Average loss: 1.6972
Iteration: 4645; Percent complete: 92.9%; Average loss: 1.7854
Iteration: 4646; Percent complete: 92.9%; Average loss: 1.8510
Iteration: 4647; Percent complete: 92.9%; Average loss: 1.6089
Iteration: 4648; Percent complete: 93.0%; Average loss: 1.8399
Iteration: 4649; Percent complete: 93.0%; Average loss: 1.6119
Iteration: 4650; Percent complete: 93.0%; Average loss: 1.9192
Iteration: 4651; Percent complete: 93.0%; Average loss:

Iteration: 4768; Percent complete: 95.4%; Average loss: 1.6700
Iteration: 4769; Percent complete: 95.4%; Average loss: 1.7700
Iteration: 4770; Percent complete: 95.4%; Average loss: 1.7199
Iteration: 4771; Percent complete: 95.4%; Average loss: 1.7218
Iteration: 4772; Percent complete: 95.4%; Average loss: 1.5558
Iteration: 4773; Percent complete: 95.5%; Average loss: 1.7413
Iteration: 4774; Percent complete: 95.5%; Average loss: 1.6994
Iteration: 4775; Percent complete: 95.5%; Average loss: 1.7195
Iteration: 4776; Percent complete: 95.5%; Average loss: 1.5052
Iteration: 4777; Percent complete: 95.5%; Average loss: 1.9453
Iteration: 4778; Percent complete: 95.6%; Average loss: 1.5991
Iteration: 4779; Percent complete: 95.6%; Average loss: 1.5945
Iteration: 4780; Percent complete: 95.6%; Average loss: 1.5446
Iteration: 4781; Percent complete: 95.6%; Average loss: 1.7581
Iteration: 4782; Percent complete: 95.6%; Average loss: 1.6370
Iteration: 4783; Percent complete: 95.7%; Average loss:

Iteration: 4900; Percent complete: 98.0%; Average loss: 1.9133
Iteration: 4901; Percent complete: 98.0%; Average loss: 1.8101
Iteration: 4902; Percent complete: 98.0%; Average loss: 1.5409
Iteration: 4903; Percent complete: 98.1%; Average loss: 1.6815
Iteration: 4904; Percent complete: 98.1%; Average loss: 1.6747
Iteration: 4905; Percent complete: 98.1%; Average loss: 1.6413
Iteration: 4906; Percent complete: 98.1%; Average loss: 1.6212
Iteration: 4907; Percent complete: 98.1%; Average loss: 1.5920
Iteration: 4908; Percent complete: 98.2%; Average loss: 1.3516
Iteration: 4909; Percent complete: 98.2%; Average loss: 1.5738
Iteration: 4910; Percent complete: 98.2%; Average loss: 1.8091
Iteration: 4911; Percent complete: 98.2%; Average loss: 1.4559
Iteration: 4912; Percent complete: 98.2%; Average loss: 1.5871
Iteration: 4913; Percent complete: 98.3%; Average loss: 1.6195
Iteration: 4914; Percent complete: 98.3%; Average loss: 1.7254
Iteration: 4915; Percent complete: 98.3%; Average loss:

## RUN AN PLAY

In [20]:
# Set dropout layers to eval mode
encoder.eval()
decoder.eval()

# Initialize search module
searcher = GreedySearchDecoder(encoder, decoder)

# Begin chatting
evaluateInput(encoder, decoder, searcher, voc)

> hello
Bot: hi . that s right . that start over there . ?
> how are you ?
Bot: i m fine celia . yourself . yourself . ! yourself
> thanks
Bot: you re welcome . your new bill please . it . ?
> q


## Reinforcement learning

Forward and backward models 

In [31]:
forward_encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
forward_decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.num_words, decoder_n_layers, dropout)
forward_encoder = encoder.to(device)
forward_decoder = decoder.to(device)

In [32]:
backward_encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
backward_decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.num_words, decoder_n_layers, dropout)
backward_encoder = encoder.to(device)
backward_decoder = decoder.to(device)

Training step for a single iteration 

In [34]:
def RL(input_variable, lengths, target_variable, mask, max_target_len, encoder, decoder, batch_size):
    # Set device options
    input_variable = input_variable.to(device)
    target_variable = target_variable.to(device)
    mask = mask.to(device)
    # Lengths for rnn packing should always be on the cpu
    lengths = lengths.to("cpu")
    
    #Initialize variables
    loss=0
    #print_losses = []
    response=[]
    
    # Forward pass through encoder
    encoder_outputs, encoder_hidden = encoder(input_variable, lengths)
    
    # Create initial decoder input (start with SOS tokens for each sentence)
    decoder_input = torch.LongTensor([[SOS_token for _ in range(batch_size)]])
    decoder_input = decoder_input.to(device)
    
    
    # Set initial decoder hidden state to the encoder's final hidden state
    decoder_hidden = encoder_hidden[:decoder.n_layers]

     # Determine if we are using teacher forcing this iteration
    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False
    
     # Forward batch of sequences through decoder one time step at a time
    if use_teacher_forcing:
        for t in range(max_target_len):
            decoder_output, decoder_hidden = decoder(
                decoder_input, decoder_hidden, encoder_outputs
            )
            # Teacher forcing: next input is current target
            decoder_input = target_variable[t].view(1, -1)
            # Calculate and accumulate loss
            mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
            loss += mask_loss
            #print_losses.append(mask_loss.item() * nTotal)
    else:
         for t in range(max_target_len):
            decoder_output, decoder_hidden = decoder(
                decoder_input, decoder_hidden, encoder_outputs
            )
            # No teacher forcing: next input is decoder's own current output
            _, topi = decoder_output.topk(1)
            decoder_input = torch.LongTensor([[topi[i][0] for i in range(batch_size)]])
            decoder_input = decoder_input.to(device)
            # Calculate and accumulate loss
            mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
            loss += mask_loss
            #print_losses.append(mask_loss.item() * nTotal)
            
            #ni or decoder_output
            response.append(topi[0][0])
            
    return (loss, max_target_len, response)

Let's define the rewards method for each action.

It is composed of 3 different types of rewards :
- `ease of answering` which is basically saying that a turn generated by a machineshould be easy to respond to. 
- `information flow` : We  want  the  agent  to  contribute new information at each turn to keep the dialogue moving and avoid repetitive sequences. 
- `semantic coherence` measures the adequacy of responses to avoid situations in which the generated replies are highly rewarded but are ungrammatical or not coherent. 

In [24]:
dull_responses = ["I don't know what you're talking about.", "I don't know.", 
 "You don't know.", "You know what I mean.", "I know what you mean.", 
 "You know what I'm saying.", "You don't know anything."]

In [25]:
def easeOfAnswering():
    NS=len(dull_responses)
    r1=0
    for d in dull_reponses:
        forward_loss, forward_len = RL()
        # log (1/P(a|s)) = CE  --> log(P(a | s)) = - CE
        if forward_len > 0:
            r1 -= forward_loss / forward_len
    if len(dull_responses) > 0:
        r1 = r1 / NS
    return r1

In [27]:
def informationFlow():
    r2=0
    if(len(responses) > 2):
        #2 representations obtained from the encoder for two consecutive turns pi and pi+1
        h_pi = responses[-3]
        h_pi1 = responses[-1]
        # length of the two vector might not match
        min_length = min(len(h_pi), len(h_pi+1))
        h_pi = h_pi[:min_length]
        h_pi1 = h_pi1[:min_length]
        #cosine similarity 
        cos_sim = 1 - scipy.spatial.distance.cosine(vec_a, vec_b)
        
        #Handle negative cos_sim
        if cos_sim <= 0:
            r2 = - cos_sim
        else:
            r2 = - np.log(cos_sim)
    return r2

In [None]:
def semanticCoherence():
    r3 = 0
    forward_loss, forward_len = RL()
    backward_loss, backward_len = RL()
    if forward_len > 0:
        r3 += forward_loss / forward_len
    if backward_len > 0:
        r3+= backward_loss / backward_len
    return r3

In [None]:
l1=0.25
l2=0.25
l3=0.5

In [None]:
def calculate_rewards(input_var, target_var, forward_encoder, forward_decoder, backward_encoder, backward_decoder, teacher_forcing_ratio):
    #rewards per episode
    ep_rewards = []
    #indice of current episode
    ep_num = []
    #list of responses
    responses = []
    #input of current episode
    ep_input = 0
    #target of current episode
    ep_target = 0
    
    #ep_num bounded -> to redefine (MEDIUM POST)
    while (ep_num <= 10):
        #generate current response with the forward model
        curr_response = RLStep()
        
        #Break if :
        # 1 -> dull response
        # 2 -> response is less than MIN_LENGTH
        # 3 -> repetition ie curr_response in responses
        if((len(curr_reponse) < MIN_LENGTH) or (curr_response in dull_responses) or (curr_response in responses)):
            break
            
        #We can add the response to responses list
        curr_response = Variable(torch.LongTensor(curr_response), requires_grad=False).view(-1, 1)
        curr_response = curr_response.cuda() if use_cuda else curr_response
        responses.append(curr_response)
        
        #Ease of answering
        r1 = easeOfAnswering()
        
        #Information flow
        r2 = informationFlow()
        
        #Semantic coherence
        r3 = semanticCoherence()
        
        #Final reward as a weighted sum of rewards
        r = l1*r1 + l2*r2 + l3*r3
        
        #Add the current reward to the list
        ep_rewards.append(r)
        
        #Next input is the current response
        ep_input = curr_response
        #Next target -> dummy
        ep_target = 0
        
        #Turn off the teacher forcing  after first iteration -> dummy target
        teacher_forcing_ratio = 0
        ep_num +=1
        
    #Take the mean of the episodic rewards
    return np.mean(ep_rewards) if len(ep_rewards) > 0 else 0 