# Building a Chatbot : PyTorch
__NOTE:__ This code has been adopted from https://pytorch.org/tutorials/chatbot_tutorial.html. We will demonstrate more details and observe the output step-by-step to have a deeper understanding. Extra diagrams to dive deeper into the architectures used will also be provided!

In [1]:
# importing the libraries
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import csv
import random
import re
import os
import unicodedata
import codecs
import itertools

In [2]:
CUDA = torch.cuda.is_available()
device = torch.device("cuda" if CUDA else "cpu")

## Part 1: Data Processing

In [3]:
lines_filepath = os.path.join("cornell movie-dialogs corpus", "movie_lines.txt")
conv_filepath = os.path.join("cornell movie-dialogs corpus", "movie_conversations.txt")


In [4]:
# Visual some lines
with open(lines_filepath, 'r') as file:
    lines = file.readlines()
for line in lines[:8]:
    print (line.strip())

L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!
L1044 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ They do to!
L985 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I hope so.
L984 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ She okay?
L925 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Let's go.
L924 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ Wow
L872 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Okay -- you're gonna need to learn how to lie.
L871 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ No


In [5]:
# Split each line of the file into a dictionary of fields (lineID, characterID, moiveID, character, text)

line_fields = ['lineID', 'characterID', 'movieID', 'character', 'text']
lines = {}
with open(lines_filepath, 'r', encoding='iso-8859-1') as f:
    for line in f:
        
        values = line.split(' +++$+++ ')
        
        # Extract Field
        lineObj = {}
        for i, field in enumerate(line_fields):
            lineObj[field] = values[i]
        
        lines[lineObj['lineID']] = lineObj

In [6]:
list(lines.items())[0]

('L1045',
 {'lineID': 'L1045',
  'characterID': 'u0',
  'movieID': 'm0',
  'character': 'BIANCA',
  'text': 'They do not!\n'})

In [7]:
# Groups fields of lines from 'loadLines' into conversations based on 'movie_conversations.txt'

conv_fields = ['character1ID', 'character2ID', 'movieID', 'utteranceIDs']
conversations = []

with open(conv_filepath, 'r', encoding='iso-8859-1') as f:
    for line in f:
        
        values = line.split(' +++$+++ ')
        #Extract field 
        
        convObj = {}
        for i, field in enumerate(conv_fields):
            convObj[field] = values[i]
            
        # convert string result from split to list, since convObj['utteranceIDs'] == "['L598485', 'L598486']"
        lineIds = eval(convObj['utteranceIDs'])
        
        #Reassamble lines
        convObj['lines'] = []
        for lineId in lineIds:
            convObj['lines'].append(lines[lineId])
        conversations.append(convObj)

In [8]:
# Extracts pairs of sentecnces from conversations
qa_pairs = []
for conversation in conversations:
    
    for i in range(len(conversation['lines']) - 1 ):
        inputLine = conversation['lines'][i]['text'].strip()
        targetLine = conversation['lines'][i+1]['text'].strip()
        
        #Filter wrong samples (if one of the lists is empty)
        if inputLine and targetLine:
            qa_pairs.append([inputLine, targetLine])

In [9]:
# Define path to new file
datafile = os.path.join('cornell movie-dialogs corpus', 'formatted_movie_lines.txt')
delimiter = '\t'

#Unescape the delimeter
delimiter = str(codecs.decode(delimiter, "unicode_escape"))

# Write new csv file
print ("\nWriting newly fromatted file...")
with open(datafile, 'w', encoding='utf-8') as outputfile:
    writer = csv.writer(outputfile, delimiter=delimiter)
    for pair in qa_pairs:
        writer.writerow(pair)
        
print ("Done writing to file")


Writing newly fromatted file...
Done writing to file


In [10]:
# Visualize some lines
datafile = os.path.join('cornell movie-dialogs corpus','formatted_movie_lines.txt')
with open(datafile, 'rb') as file:
    lines = file.readlines()
for line in lines[:8]:
    print (line)

b"Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\tWell, I thought we'd start with pronunciation, if that's okay with you.\r\r\n"
b"Well, I thought we'd start with pronunciation, if that's okay with you.\tNot the hacking and gagging and spitting part.  Please.\r\r\n"
b"Not the hacking and gagging and spitting part.  Please.\tOkay... then how 'bout we try out some French cuisine.  Saturday?  Night?\r\r\n"
b"You're asking me out.  That's so cute. What's your name again?\tForget it.\r\r\n"
b"No, no, it's my fault -- we didn't have a proper introduction ---\tCameron.\r\r\n"
b"Cameron.\tThe thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\r\r\n"
b"The thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\tSeems like she could get a date easy enough...\r\r\n"
b'Why?\tU

### Processing the Words

In [11]:
PAD_token = 0 # Used for padding short sentences
SOS_token = 1 # Start-of-sentence token <START>
EOS_token = 2 # End-of-sentence token <END>

class Vocabulary:
    
    def __init__(self, name):
        
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: 'PAD', SOS_token: 'SOS', EOS_token: 'EOS'}
        self.num_words = 3 # Count SOS, EOS, PAD
        
    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)
            
    def addWord(self, word):
        
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1
            
        else:
            self.word2count[word] += 1
            
    # Remove the words below a certain count threshold        
    def trim(self, min_count):
        
        keep_words = []
        for k, v in self.word2count.items():
            if v >= min_count:
                keep_words.append(k)
                
        print ('keep_words {} / {} = {:.4f}'.format(len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index)))

        # Reinitialize dictionaries
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3 # Count default tokens
          
        for word in keep_words:
            self.addWord(word)
        

In [12]:
# Turn a Unicode string to plain ASCII and 'NFD' stands for normal form decomposed
def unicodeToAscii(s):
    return ''.join(c for c in unicodedata.normalize('NFD', s) if unicodedata.category(c) != 'Mn')

In [13]:
unicodeToAscii('Atatürk')

'Ataturk'

In [14]:
def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    # Replace any .!? by a whitespace + the character --> '!' = ' !'. \1 means the first bracketed group --> [.!?]. r is to
    # not consider \1 as a character (r to escape a backslash).
    s = re.sub(r'([.!?])',r' \1', s)
    
    # Remove any character that is not a sequence of lower or upper case letters. + means one or more
    s = re.sub(r'[^a-zA-Z.!?]+', r' ', s)
    # Remove a sequence of whitespace characters
    s = re.sub(r'\s+', r' ', s).strip()
    return s

In [15]:
# Test the function 
normalizeString("aa123aa!s's  dd?")

'aa aa !s s dd ?'

In [16]:
datafile = os.path.join('cornell movie-dialogs corpus', 'formatted_movie_lines.txt')

# Read the file and split it into lines
print("Reading and processing file... Please Wait")
lines = open(datafile, encoding='utf-8').read().strip().split('\n')

# Split every line into pairs and normalize
pairs = [[normalizeString(s) for s in pair.split('\t')] for pair in lines]

print ('Done Reading!')

voc = Vocabulary('cornell movie-dialogs corpus')

Reading and processing file... Please Wait
Done Reading!


## Filtering the Text

In [17]:
# Returns True if both sentences in a pair 'p' are under the MAX_LENGTH threshold
MAX_LENGTH = 10 # Maximum sentence length to consider max words)

def filterPair(p):
    
    # Input sequences need to presever the last word for EOS taken
    return len(p[0].split()) < MAX_LENGTH and len(p[1].split()) < MAX_LENGTH

# Filter pairs using filterPair condition
def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]


In [18]:
pairs = [pair for pair in pairs if len(pair)> 1]
print ("There are {} pairs/conversations in the dataset".format(len(pairs)))
pairs = filterPairs(pairs)
print ("After filtering, there are {} pairs/conversations in the dataset".format(len(pairs)))


There are 221282 pairs/conversations in the dataset
After filtering, there are 64271 pairs/conversations in the dataset


## Getting Rid of the Rare Words

In [19]:
# Loop through each pair of and add the quesiton and reply sentence to the vocabulary

for pair in pairs:
    voc.addSentence(pair[0])
    voc.addSentence(pair[1])
    
print("Counted words:", voc.num_words)

for pair in pairs[:10]:
    print (pair)

Counted words: 18008
['there .', 'where ?']
['you have my word . as a gentleman', 'you re sweet .']
['hi .', 'looks like things worked out tonight huh ?']
['you know chastity ?', 'i believe we share an art instructor']
['have fun tonight ?', 'tons']
['well no . . .', 'then that s all you had to say .']
['then that s all you had to say .', 'but']
['but', 'you always been this selfish ?']
['do you listen to this crap ?', 'what crap ?']
['what good stuff ?', 'the real you .']


In [20]:
MIN_COUNT = 3 # Minimum word count threshold for trimming

def trimRareWords(voc, pairs, MIN_COUNT):
    
    # Trim words used under the MIN_COUNT from the voc
    voc.trim(MIN_COUNT)
    # Filter out pairs with trimmed words
    keep_pairs = []
    for pair in pairs:
        
        input_sentence = pair[0]
        output_sentence = pair[1]
        keep_input = True
        keep_output = True
        
        # Check input sentence
        for word in input_sentence.split(' '):
            if word not in voc.word2index:
                keep_input = False
                break
                
        # Check output sentence
        for word in output_sentence.split(' '):
            if word not in voc.word2index:
                keep_output = False
                break
                
        # Only keep pairs that do not contain trimmed word(s) in their input input or output sentence
        if keep_input and keep_output:
            keep_pairs.append(pairs)
            
    print ("Trimmed from {} pairs to {}, {:4f} of total".format(len(pairs), len(keep_pairs), len(keep_pairs) / len(pairs)))
    return keep_pairs
# trim voc and pairs
pairs = trimRareWords(voc, pairs, MIN_COUNT)

keep_words 7823 / 18005 = 0.4345
Trimmed from 64271 pairs to 53165, 0.827200 of total


## Preparing the Data

In [20]:
def indexesFromSentence(voc, sentence):
    return [voc.word2index[word] for word in sentence.split(' ')] + [EOS_token]


In [21]:
pairs[1][0]

'you have my word . as a gentleman'

In [22]:
# Test the function
indexesFromSentence(voc,pairs[1][0])

[7, 8, 9, 10, 4, 11, 12, 13, 2]

In [23]:
# Define some samples for testing
inp = []
out = []

for pair in pairs[:10]:
    inp.append(pair[0])
    out.append(pair[1])
    
print (inp)
print (len(inp))

indexes = [indexesFromSentence(voc, sentence) for sentence in inp]
indexes

['there .', 'you have my word . as a gentleman', 'hi .', 'you know chastity ?', 'have fun tonight ?', 'well no . . .', 'then that s all you had to say .', 'but', 'do you listen to this crap ?', 'what good stuff ?']
10


[[3, 4, 2],
 [7, 8, 9, 10, 4, 11, 12, 13, 2],
 [16, 4, 2],
 [7, 24, 25, 6, 2],
 [8, 33, 22, 6, 2],
 [35, 36, 4, 4, 4, 2],
 [37, 38, 39, 40, 7, 41, 42, 43, 4, 2],
 [44, 2],
 [49, 7, 50, 42, 47, 51, 6, 2],
 [52, 53, 54, 6, 2]]

## Understading the Zip Function

#### code snippets before _zeroPadding__ are rough.

In [24]:
a = ['A','B','C'] # in case of ['A','B','C', 'D', 'E'] 'D' and 'E' will be ignored.
b = [1,2,3]
list(zip(a,b))

[('A', 1), ('B', 2), ('C', 3)]

In [25]:
a = ['A','B','C', 'D', 'E'] # in case of ['A','B','C', 'D', 'E'] 'D' and 'E' will not be ignored.
b = [1,2,3]
list(itertools.zip_longest(a,b))

[('A', 1), ('B', 2), ('C', 3), ('D', None), ('E', None)]

In [26]:
a = [[3, 4, 2],
 [7, 8, 9, 10, 4, 11, 12, 13, 2],
 [16, 4, 2],
 [7, 24, 25, 6, 2],
 [8, 33, 22, 6, 2],
 [35, 36, 4, 4, 4, 2],
 [37, 38, 39, 40, 7, 41, 42, 43, 4, 2],
 [44, 2],
 [49, 7, 50, 42, 47, 51, 6, 2],
 [52, 53, 54, 6, 2]]

list(itertools.zip_longest(*a, fillvalue = 0)) # To fill the None values with the 0 'fillvalue = 0'

[(3, 7, 16, 7, 8, 35, 37, 44, 49, 52),
 (4, 8, 4, 24, 33, 36, 38, 2, 7, 53),
 (2, 9, 2, 25, 22, 4, 39, 0, 50, 54),
 (0, 10, 0, 6, 6, 4, 40, 0, 42, 6),
 (0, 4, 0, 2, 2, 4, 7, 0, 47, 2),
 (0, 11, 0, 0, 0, 2, 41, 0, 51, 0),
 (0, 12, 0, 0, 0, 0, 42, 0, 6, 0),
 (0, 13, 0, 0, 0, 0, 43, 0, 2, 0),
 (0, 2, 0, 0, 0, 0, 4, 0, 0, 0),
 (0, 0, 0, 0, 0, 0, 2, 0, 0, 0)]

In [27]:
def zeroPadding(l, fillvalue = 0):    
    return list(itertools.zip_longest(*l, fillvalue=fillvalue))

In [28]:
leng = [len(ind) for ind in indexes]
max(leng)

10

In [29]:
# Test the function
test_result = zeroPadding(indexes)
print (len(test_result)) # The max length is now the number of rows
test_result

10


[(3, 7, 16, 7, 8, 35, 37, 44, 49, 52),
 (4, 8, 4, 24, 33, 36, 38, 2, 7, 53),
 (2, 9, 2, 25, 22, 4, 39, 0, 50, 54),
 (0, 10, 0, 6, 6, 4, 40, 0, 42, 6),
 (0, 4, 0, 2, 2, 4, 7, 0, 47, 2),
 (0, 11, 0, 0, 0, 2, 41, 0, 51, 0),
 (0, 12, 0, 0, 0, 0, 42, 0, 6, 0),
 (0, 13, 0, 0, 0, 0, 43, 0, 2, 0),
 (0, 2, 0, 0, 0, 0, 4, 0, 0, 0),
 (0, 0, 0, 0, 0, 0, 2, 0, 0, 0)]

In [30]:
def binaryMatrix(l, value=0):
    
    m = []
    for i , seq in enumerate(l):
        m.append([])
        for token in seq:   
            if token == PAD_token:
                m[i].append(0)
            else:
                m[i].append(1)
                
    return m


In [31]:
binary_result = binaryMatrix(test_result)
binary_result

[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 [1, 1, 1, 1, 1, 1, 1, 0, 1, 1],
 [0, 1, 0, 1, 1, 1, 1, 0, 1, 1],
 [0, 1, 0, 1, 1, 1, 1, 0, 1, 1],
 [0, 1, 0, 0, 0, 1, 1, 0, 1, 0],
 [0, 1, 0, 0, 0, 0, 1, 0, 1, 0],
 [0, 1, 0, 0, 0, 0, 1, 0, 1, 0],
 [0, 1, 0, 0, 0, 0, 1, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 1, 0, 0, 0]]

In [32]:
# Returns padded input sequence tensor and as well as a tensor of lengths for each of the sequences in the batch
def inputVar(l, voc):
    indexes_batch =  [indexesFromSentence(voc, sentence) for sentence in l]
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    padVar = torch.LongTensor(padList)
    return padVar, lengths

In [33]:
# Returns padded target sequence tensor, padding mask, and max target length
def outputVar(l, voc):
    
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    max_target_len = max([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    mask = binaryMatrix(padList)
    mask = torch.ByteTensor(mask)
    padVar = torch.LongTensor(padList)
    return padVar, mask, max_target_len

In [34]:
# Returns all items for a given batch of pairs
def batch2TrainData(voc, pair_batch):
    
    # Sort the questions in descending length
    pair_batch.sort(key=lambda x: len(x[0].split(" ")), reverse=True)
    input_batch, output_batch = [], []
    for pair in pair_batch:
        
        input_batch.append(pair[0])
        output_batch.append(pair[1])
        
    inp, lengths = inputVar(input_batch, voc)
    #assert len(inp) == lengths[0]
    
    output, mask, max_target_len = outputVar(output_batch, voc)
    return inp, lengths, output, mask, max_target_len
    
    

In [35]:
# Example for validation 
small_batch_size = 5
batches = batch2TrainData(voc, [random.choice(pairs) for _ in range(small_batch_size)])
input_variable, lengths, target_variable, mask, max_target_len = batches

print ("input_variable:")
print (input_variable)
print ("lengths:", lengths)
print ("target_variable:")
print (target_variable)
print ("mask:")
print (mask)
print ("max_target_len", max_target_len)


input_variable:
tensor([[   7,   96,   26, 3879,  130],
        [ 124,   28,   99, 4291,   38],
        [ 180,  166,  340,   69,    2],
        [7932,    6,    4,    2,    0],
        [   6,    2,    2,    0,    0],
        [   2,    0,    0,    0,    0]])
lengths: tensor([6, 5, 5, 4, 3])
target_variable:
tensor([[   28,   351,  4960,   336,    52],
        [   99,     4,   461,     4,    71],
        [  117,     2,     7,     2,     7],
        [  237,     0,   720,     0,   285],
        [   47,     0,    12,     0,   255],
        [12647,     0,  4958,     0,    38],
        [    4,     0,     4,     0,     6],
        [    2,     0,     2,     0,     2]])
mask:
tensor([[1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 0, 1, 0, 1],
        [1, 0, 1, 0, 1],
        [1, 0, 1, 0, 1],
        [1, 0, 1, 0, 1],
        [1, 0, 1, 0, 1]], dtype=torch.uint8)
max_target_len 8


## Encoder

In [36]:
class EncoderRNN(nn.Module):
    def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
        super(EncoderRNN, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embedding = embedding
        # Initialize GRU: the input_size and hidden_size parameters are both set to 'hidden_size'
        # because our input size is a word embedding with number of features == hidden_size
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=(0 if n_layers == 1 else dropout), bidirectional=True)
        
    def forward(self, input_seq, input_lengths, hidden=None):
        # input_seq: batch of input sentences: shape=(max_length, batch_size)
        # input_lengths: List of sentence lengths corresponding to each sentence in the batch
        # hidden state, of shape: (n_layers x num_directions, batch_size, hidden_size)
        # Convert word indexes to embeddings
        embedded = self.embedding(input_seq)
        # Pack padded batch of sequences for RNN module
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
        # Forward pass through GRU
        outputs,hidden = self.gru(packed, hidden)
        # Unpack padding
        outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs)
        # Sum bidirectional GRU outputs
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:]
        # Return output and final hidden state
        return outputs, hidden
        # outputs: the output features h_t from the last layer of the GRU, for each timestep (sum of bidirectional outputs)
        # outputs: shape=(max_length, batch_size, hidden_size)
        # hidden: hidden state for the last timestep, of shape=(n_layers x num_directions, batch_size, hidden_size)

In [37]:
# Luong attention layer
class Attn(torch.nn.Module):
    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()
        self.method = method
        self.hidden_size = hidden_size
        
    def dot_score(self, hidden, encoder_output):
        # Element-Wise Multiply the current target decoder state with the encoder output and sum them
        return torch.sum(hidden * encoder_output, dim=2)
    
    def forward(self, hidden, encoder_outputs):
        # hidden of shape: (1, batch_size, hidden_size)
        # encoder_outputs of shape: (max_length, batch_size, hidden_size) 
        # (1, batch_size, hidden_size) * (max_length, batch_size, hidden_size) = (max_length, batch_size, hidden_size)
        
        # Calculate the attention weights (energies)
        attn_energies = self.dot_score(hidden, encoder_outputs) # (max_length, batch_size)
        # Transpose max_length and batch_size dimentions 
        attn_energies = attn_energies.t()                       # (batch_size, max_length)
        # Return the softmax normalized probability scores (with added dimension)
        return F.softmax(attn_energies, dim=1).unsqueeze(1)     # (batch_size, 1, max_length)

For the decoder we will manually feed our batch one time step in a time. This means that our embedded word tensor and GRU output will both have shape (1, batch_size, hidden_size). The steps are: Get embedding of current input word. Forward through unidirectional GRU. Calculate attention weights from the current GRU output. Multiply attention weights to encoder outputs to get new "weighted sum" context vector. Concatenate weighted context vector and GRU output. Predict next word, and finally Return output and final hidden state.

In [38]:
class LuongAttnDecoderRNN(nn.Module):
    def __init__(self, attn_model, embedding, hidden_size, output_size, n_layers=1, dropout=0.1):
        super(LuongAttnDecoderRNN, self).__init__()
        self.attn_model = attn_model
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = dropout
        
        # Define layers
        self.embedding = embedding
        self.embedding_dropout = nn.Dropout(dropout)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=(0 if n_layers == 1 else dropout))
        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        
        self.attn = Attn(attn_model, hidden_size)
        
    def forward(self, input_step, last_hidden, encoder_outputs):
        # input_step: one time step (one word) of input sequence batch; shape=(1, batch_size)
        # last_hidden: final hidden layer of GRU; shape=(n_layers x num_directions, batch_size, hidden_size)
        # encoder_outputs: encoder model's output; shape=(max_length, batch_size, hidden_size)
        # Note: we run this one step (word) at a time
        
        # Get embedding of current input word
        embedded = self.embedding(input_step)
        embedded = self.embedding_dropout(embedded)
        # Forward through unidirectional GRU
        rnn_output, hidden = self.gru(embedded, last_hidden)
        # rnn_output of shape: (1, batch, num_direciton * hidden_size)
        # hidden of shape: (num_layers * num_directions, batch, hidden_size)
        
        # Calculate attention weights from the current GRU output
        attn_weights = self.attn(rnn_output, encoder_outputs)
        # Multiply attention weights to encoder outputs to get new "weighted sum" context vector
        # (batch_size, 1, max_length) bmm with (batch_size, max_length, hidden) = (batch_size, 1, hidden)
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1))
        # Concatenate weighted context vector and GRU output
        rnn_output = rnn_output.squeeze(0)
        context = context.squeeze(1)
        concat_input = torch.cat((rnn_output, context), 1)
        concat_output = torch.tanh(self.concat(concat_input))
        # Predict next word using Luong eq. 6
        output = self.out(concat_output)
        output = f.softmax(output, dim=1)
        # Return output and final hidden state
        return output, hidden
        # output: softmax normalized tensor giving probabilities of each word being the correct next word in the decoded sequence
        # shape (batch_size, voc.num_words)
        # hidden: final hidden state of GRU; shape=(n_layers x num_directions batch_size, hidden_size)

## We're Done with Building the Architecture, Now Let's move on to the Training code

Since we are dealing with batches of padded sequences, we cannot simply consider all elements of the tensor when calculating loss. We define maskLLLoss to calculate our loss based on our decoder's output tensor, the target tensor, and a binary mask tensor describing the padding of the target tensor. This loss function calculates the average negative log likelihood of the elements that correspond to a 1 in the mask tensor.

In [39]:
def maskNLLLoss(decoder_out, target, mask):
    nTotal = mask.sum
    target = target.view(-1, 1)
    # decoder_out shape: (batch_size, vocab_size), target_size = (batch_size, 1)
    gathered_tensor = torch.gather(decoder_out, 1, target)
    # Calculate the Negative Log Likelihood Loss
    crossEntropy = -torch.log(gathered_tensor)
    # Select the non-zero elements
    loss = loss.mean()
    loss = loss.to(device)
    return loss, nTotal.item()

In [None]:
# Visualizing what's happening in one iteration. Only run this for visualization.
small_batch_size = 5
batches = batch2TrainData(voc, [random.choice(pairs) for _ in range(small_batch_size)])
input_variable, lengths, target_variable, mask, max_target_len = batches

print("input_variable shape:", input_variable.shape)
print("lengths shape:", lengths.shape)
print("target_variable shape:", target_variable.shape)
print("mask shape:", mask.shape)
print("max_target_len:", max_target_len)

# Define the parameters
hidden_size = 500
encoder_n_layers = 2
decoder_n_layers = 2
dropout = 0.1
attn_model = 'dot'
embedding = nn.Embedding(voc.num_words, hidden_size)

# Define the encoder and Decoder
encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.num_words, decoder_n_layers, dropout)
encoder = encoder.to(device)
decoder = decoder.to(divice)

# Ensure dropout layers are in train mode
encoder.train()
decoder.train()

# initialize optimizers
encoder_optimizer = optim.Adam(encoder.parameters(), lr = 0.0001)
decoder_optimizer = optim.Adam(decoder.parameters(), lr = 0.0001)
encoder_optimizer.zero_grad()
decoder_optimizer.zero_grad()

input_variable = input_variable.to(device)
lengths = lengths.to(device)
target_variable = target_variable.to(device)
mask = mask.to(device)

loss = 0
print_losses = []
n_totals = 0

encoder_outputs, encoder_hidden = encoder(input_variable, lengths)
print("Encoder Outputs Shape:", encoder_outputs.shape)
print("Last Encoder Hidden Shape", encoder_hidden.shape)

decoder_input = torch.LongTensor([[SOS_token for _ in range(small_batch_size)]])
decoder_input = decoder_input.to(device)
print("initial Decoder Input Shape:", decoder_input.shape)
print(decoder_input)

# set initial decoder hidden state to the encoder's final hidden state
decoder_hidden = encoder_hidden[:decoder.n_layers]
print("Initial Decoder hidden state shape:", decoder_hidden.shape)
print("\n")
print("-----------------------------------------------------------------------")
print("Now lets look what's happening in every timestep of the GRU!")
print("-----------------------------------------------------------------------")
print("\n")

# Assume we are using Teacher Forcing
