In [1]:
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import csv
import random
import re
import os
import unicodedata
import codecs
import itertools

In [2]:
CUDA=torch.cuda.is_available()


In [3]:
device=torch.device('cuda' if CUDA else 'cpu')
device
#cpu, cuda, mkldnn, opengl, opencl, ideep, hip, msnpu

device(type='cuda')

### Part 1 : Data Processing

In [4]:
cwd=os.getcwd()
cwd

'C:\\Users\\jashj\\Desktop\\PyTorch'

In [5]:
lists=os.listdir()
lists

['.ipynb_checkpoints',
 'Chatbot Implementation.ipynb',
 'cornell movie-dialogs corpus',
 'cornell_movie_dialogs_corpus.zip',
 'PyTorch Basics.ipynb',
 '__MACOSX']

In [6]:
corpus = os.path.join(cwd,lists[2])

In [7]:
corpus

'C:\\Users\\jashj\\Desktop\\PyTorch\\cornell movie-dialogs corpus'

In [8]:
corpus_dir=os.listdir(corpus)
corpus_dir

['.DS_Store',
 'chameleons.pdf',
 'movie_characters_metadata.txt',
 'movie_conversations.txt',
 'movie_lines.txt',
 'movie_titles_metadata.txt',
 'raw_script_urls.txt',
 'README.txt']

In [9]:
lines_filepath=os.path.join(corpus,corpus_dir[4])
conv_filepath=os.path.join(corpus,corpus_dir[3])

In [10]:
lines_filepath

'C:\\Users\\jashj\\Desktop\\PyTorch\\cornell movie-dialogs corpus\\movie_lines.txt'

In [11]:
# Visualize some lines
with open(lines_filepath,'r',encoding='iso-8859-1') as file:
    movie_lines = file.readlines()

In [12]:
movie_lines[1].strip().split(' +++$+++ ')

['L1044', 'u2', 'm0', 'CAMERON', 'They do to!']

In [13]:
# Split each line of the file into a dictionary of fields (LineID, CharacterID, MovieID, character, text)
line_fields = ['LineID', 'CharacterID', 'MovieID', 'character', 'text']
lines={}



In [14]:
for line in movie_lines:
    values = line.split(' +++$+++ ')

    # Extract fields
    lineObj = {}
    for i , field in enumerate(line_fields):
        lineObj[field]=values[i]
    lines[lineObj['LineID']]=lineObj

In [15]:
lines

{'L1045': {'LineID': 'L1045',
  'CharacterID': 'u0',
  'MovieID': 'm0',
  'character': 'BIANCA',
  'text': 'They do not!\n'},
 'L1044': {'LineID': 'L1044',
  'CharacterID': 'u2',
  'MovieID': 'm0',
  'character': 'CAMERON',
  'text': 'They do to!\n'},
 'L985': {'LineID': 'L985',
  'CharacterID': 'u0',
  'MovieID': 'm0',
  'character': 'BIANCA',
  'text': 'I hope so.\n'},
 'L984': {'LineID': 'L984',
  'CharacterID': 'u2',
  'MovieID': 'm0',
  'character': 'CAMERON',
  'text': 'She okay?\n'},
 'L925': {'LineID': 'L925',
  'CharacterID': 'u0',
  'MovieID': 'm0',
  'character': 'BIANCA',
  'text': "Let's go.\n"},
 'L924': {'LineID': 'L924',
  'CharacterID': 'u2',
  'MovieID': 'm0',
  'character': 'CAMERON',
  'text': 'Wow\n'},
 'L872': {'LineID': 'L872',
  'CharacterID': 'u0',
  'MovieID': 'm0',
  'character': 'BIANCA',
  'text': "Okay -- you're gonna need to learn how to lie.\n"},
 'L871': {'LineID': 'L871',
  'CharacterID': 'u2',
  'MovieID': 'm0',
  'character': 'CAMERON',
  'text': 'No

In [16]:
conv_filepath

'C:\\Users\\jashj\\Desktop\\PyTorch\\cornell movie-dialogs corpus\\movie_conversations.txt'

In [17]:
# Visualize some lines
with open(conv_filepath,'r',encoding='iso-8859-1') as file:
    conv_lines = file.readlines()

In [18]:
conv_lines[1].strip().split('+++$+++')

['u0 ', ' u2 ', ' m0 ', " ['L198', 'L199']"]

In [19]:
# Group fields of lines from 'LoadLines' into conversations based on movie_conversations.txt
conv_fields = ['Character1ID', 'Character2ID', 'MovieID', 'UtteranceIDs']
conversations = []

with open(conv_filepath,'r',encoding='iso-8859-1') as f:
    for line in f:
        values=line.split(' +++$+++ ')
        
        #Extract fields
        
        convObj={}
        
        for i , field in enumerate(conv_fields):
            convObj[field]=values[i]
        
        # Convert string result from split to list since convObj['UtteranceIDs']=='['L198', 'L199']'
        LineIDs = eval(convObj['UtteranceIDs'])
        
        #Reassemble lines
        
        convObj['Lines']=[]
        
        for lineid in LineIDs:
            convObj['Lines'].append(lines[lineid])
        conversations.append(convObj)


In [20]:
conversations

[{'Character1ID': 'u0',
  'Character2ID': 'u2',
  'MovieID': 'm0',
  'UtteranceIDs': "['L194', 'L195', 'L196', 'L197']\n",
  'Lines': [{'LineID': 'L194',
    'CharacterID': 'u0',
    'MovieID': 'm0',
    'character': 'BIANCA',
    'text': 'Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\n'},
   {'LineID': 'L195',
    'CharacterID': 'u2',
    'MovieID': 'm0',
    'character': 'CAMERON',
    'text': "Well, I thought we'd start with pronunciation, if that's okay with you.\n"},
   {'LineID': 'L196',
    'CharacterID': 'u0',
    'MovieID': 'm0',
    'character': 'BIANCA',
    'text': 'Not the hacking and gagging and spitting part.  Please.\n'},
   {'LineID': 'L197',
    'CharacterID': 'u2',
    'MovieID': 'm0',
    'character': 'CAMERON',
    'text': "Okay... then how 'bout we try out some French cuisine.  Saturday?  Night?\n"}]},
 {'Character1ID': 'u0',
  'Character2ID': 'u2',
  'MovieID': 'm0',
  'Uttera

In [21]:
conv_fields

['Character1ID', 'Character2ID', 'MovieID', 'UtteranceIDs']

In [22]:
values

['u9030', 'u9034', 'm616', "['L666520', 'L666521', 'L666522']\n"]

In [23]:
conversations[0]['Lines'][0]['text'].strip()

'Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.'

In [24]:
#Extract pair sentences from conversations
qa_pairs = []

for conversation in conversations:
    # Iterate over all the lines of the conversation
    for i in range(len(conversation['Lines'])-1):
        inputLine=conversation['Lines'][i]['text'].strip()
        targetLine=conversation['Lines'][i+1]['text'].strip()
        # Filter wromg samples if one of the lists is empty
        if inputLine and targetLine:
            qa_pairs.append([inputLine,targetLine])

In [25]:
len(qa_pairs)

221282

In [26]:
qa_pairs[9]

['Gosh, if only we could find Kat a boyfriend...', 'Let me see what I can do.']

In [27]:
# Define path for new file

datafile = os.path.join(lists[2],'formatted_movie_lines.txt')
delimiter = '\t'

# Unescape the delimiter
delimiter = str(codecs.decode(delimiter,'unicode_escape'))

# Write new csv file
print('\n Writing newly formatted file')

with open(datafile , 'w' , encoding = 'utf-8') as outputfile:
    writer = csv.writer(outputfile , delimiter=delimiter)
    
    for pair in qa_pairs:
        writer.writerow(pair)
        
print('Done writing to file')


 Writing newly formatted file
Done writing to file


In [28]:
# Visualize some lines

datafile = os.path.join(lists[2],'formatted_movie_lines.txt')

with open(datafile,'rb') as file:
    lines = file.readlines()
for line in lines[:16]:
    print(line)

b"Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\tWell, I thought we'd start with pronunciation, if that's okay with you.\r\r\n"
b"Well, I thought we'd start with pronunciation, if that's okay with you.\tNot the hacking and gagging and spitting part.  Please.\r\r\n"
b"Not the hacking and gagging and spitting part.  Please.\tOkay... then how 'bout we try out some French cuisine.  Saturday?  Night?\r\r\n"
b"You're asking me out.  That's so cute. What's your name again?\tForget it.\r\r\n"
b"No, no, it's my fault -- we didn't have a proper introduction ---\tCameron.\r\r\n"
b"Cameron.\tThe thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\r\r\n"
b"The thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\tSeems like she could get a date easy enough...\r\r\n"
b'Why?\tU

In [29]:
PAD_token = 0  #Used for padding short sentences
SOS_token = 1  #Start-of-sentence token <START>
EOS_token = 2  #End-of-sentence token <END>

In [30]:
class Vocabulary:
    def __init__(self, name):
        self.name=name
        self.word2index = {}
        self.word2count = {}  # Stores the count of the word
        self.index2word = {PAD_token:'PAD' , SOS_token:'SOS' , EOS_token:'EOS'}
        self.num_words = 3 # Count SOS, EOS, PAD
        
    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)
        
    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words+=1
        else:
            self.word2count[word]+=1
            
            
    # Remove words below a certain count threshold
    
    def trim(self, min_count):
        keep_words = []
        for k,v in self.word2count.items():
            if v>= min_count:
                keep_words.append(k)
        
        print()
        
        #Reiniitalize dictionaries
        
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token:'PAD' , SOS_token:'SOS' , EOS_token:'EOS'}
        self.num_words = 3 # Count default tokens
        
        for word in keep_words:
            self.addWord(word)
        

In [31]:
# Turn a Unicode string to plain ASCII
def unicodeToAscii(s):
    return ''.join(c for c in unicodedata.normalize('NFD',s) if unicodedata.category(c)!='Mn')

In [32]:
# Function testing
unicodeToAscii('Excusez-moi, où est la gare ?')

'Excusez-moi, ou est la gare ?'

In [33]:
# Lowercase, trim white spaces, lines..etc, and remove non-letter characters
def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    # Replace any .!? by a whitespace + the character --> '!'=' !'
    # \1 means the first bracketed group --> [,!?]
    # r is to not consider \1 as a character (r to escape a backslash).    
    s = re.sub(r'([.!?])',r' \1',s)

    # + means one or more
    # Remove any character other than a-zA-Z.!?
    s = re.sub(r'[^a-zA-Z,.!?]+',r' ',s)
    
    # Remove sequence of whitespace characters
    s = re.sub(r'\s+',r' ',s).strip()
    return s

In [34]:
normalizeString("aa123aa!s's  dd?")

'aa aa !s s dd ?'

In [35]:
lists[2]

'cornell movie-dialogs corpus'

In [36]:
datafile

'cornell movie-dialogs corpus\\formatted_movie_lines.txt'

In [37]:
# Read the file and split into lines
print('Reading and processing file...')
lines = open(datafile, encoding='utf-8').read().strip().split('\n')

# Split every line into pairs and normalize
pairs = [[normalizeString(s) for s in pair.split('\t')] for pair in lines]
print('Done Reading!')

voc=Vocabulary(lists[2])

Reading and processing file...
Done Reading!


In [38]:
pairs[0]

['can we make this quick ? roxanne korrine and andrew barrett are having an incredibly horrendous public break up on the quad . again .',
 'well, i thought we d start with pronunciation, if that s okay with you .']

In [39]:
lines[0]

"Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\tWell, I thought we'd start with pronunciation, if that's okay with you."

In [40]:
len(pairs)

442563

In [41]:
MAX_LENGTH=10
# Returns True if both sentences in a pair are under the MAX_LENGTH threshold

def filterPair(p):
    # Input sequences need to preserve the last word for EOS_token
    return len(p[0].split())<MAX_LENGTH and len(p[1].split())<MAX_LENGTH

# Filter pairs using filterPair condition
def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]


In [42]:
pairs[0:4]

[['can we make this quick ? roxanne korrine and andrew barrett are having an incredibly horrendous public break up on the quad . again .',
  'well, i thought we d start with pronunciation, if that s okay with you .'],
 [''],
 ['well, i thought we d start with pronunciation, if that s okay with you .',
  'not the hacking and gagging and spitting part . please .'],
 ['']]

In [43]:
pairs[1]

['']

In [44]:
pairs = [pair for pair in pairs if len(pair)>1]
print('pairs/conversations in the dataset',len(pairs))
pairs=filterPairs(pairs)
print('pairs/conversations after filtering',len(pairs))

pairs/conversations in the dataset 221282
pairs/conversations after filtering 64251


In [45]:
pairs[0:4]

[['there .', 'where ?'],
 ['you have my word . as a gentleman', 'you re sweet .'],
 ['hi .', 'looks like things worked out tonight, huh ?'],
 ['you know chastity ?', 'i believe we share an art instructor']]

In [46]:
pairs[1]

['you have my word . as a gentleman', 'you re sweet .']

In [47]:
# Loop through each pair and add the question and reply sentence to the vocabulary

for pair in pairs:
    voc.addSentence(pair[0])
    voc.addSentence(pair[1])
    
print('Counter words:',voc.num_words)

Counter words: 21076


In [48]:
pairs

[['there .', 'where ?'],
 ['you have my word . as a gentleman', 'you re sweet .'],
 ['hi .', 'looks like things worked out tonight, huh ?'],
 ['you know chastity ?', 'i believe we share an art instructor'],
 ['have fun tonight ?', 'tons'],
 ['well, no . . .', 'then that s all you had to say .'],
 ['then that s all you had to say .', 'but'],
 ['but', 'you always been this selfish ?'],
 ['do you listen to this crap ?', 'what crap ?'],
 ['what good stuff ?', 'the real you .'],
 ['the real you .', 'like my fear of wearing pastels ?'],
 ['wow', 'let s go .'],
 ['she okay ?', 'i hope so .'],
 ['they do to !', 'they do not !'],
 ['did you change your hair ?', 'no .'],
 ['no .', 'you might wanna think about it'],
 ['who ?', 'joey .'],
 ['great', 'would you mind getting me a drink, cameron ?'],
 ['it s more', 'expensive ?'],
 ['hey, sweet cheeks .', 'hi, joey .'],
 ['where ve you been ?', 'nowhere . . . hi, daddy .'],
 ['you are so completely unbalanced .', 'can we go now ?'],
 ['what ?', 'in t

In [49]:
# Trimming rare words

MIN_COUNT = 3 # Minimum word count threshold for trimming

def trimRareWords(voc, pair, MIN_COUNT):
    # Trim words used under the MIN_COUNT from the voc
    voc.trim(MIN_COUNT)
    # Filter out pairs with trimmed words
    keep_pairs = []
    for pair in pairs:
        input_sentence = pair[0]
        output_sentence = pair[1]
        keep_input = True
        keep_output = True
        
        # Check input sentence
        for word in input_sentence.split(' '):
            if word not in voc.word2index:
                keep_input = False
                break
                
        # Check output sentence
        for word in output_sentence.split(' '):
            if word not in voc.word2index:
                keep_output = False
                break
                
        # Only keep pairs that do not contain trimmed words
        
        if keep_input and keep_output:
            keep_pairs.append(pair)
            
    return keep_pairs

# Trim voc and pairs
pairs=trimRareWords(voc, pairs, MIN_COUNT)




In [50]:
pairs

[['there .', 'where ?'],
 ['you have my word . as a gentleman', 'you re sweet .'],
 ['hi .', 'looks like things worked out tonight, huh ?'],
 ['have fun tonight ?', 'tons'],
 ['well, no . . .', 'then that s all you had to say .'],
 ['then that s all you had to say .', 'but'],
 ['but', 'you always been this selfish ?'],
 ['do you listen to this crap ?', 'what crap ?'],
 ['what good stuff ?', 'the real you .'],
 ['wow', 'let s go .'],
 ['she okay ?', 'i hope so .'],
 ['they do to !', 'they do not !'],
 ['did you change your hair ?', 'no .'],
 ['no .', 'you might wanna think about it'],
 ['who ?', 'joey .'],
 ['great', 'would you mind getting me a drink, cameron ?'],
 ['it s more', 'expensive ?'],
 ['where ve you been ?', 'nowhere . . . hi, daddy .'],
 ['what ?', 'in th . for a month'],
 ['in th . for a month', 'why ?'],
 ['why ?', 'he was, like, a total babe'],
 ['he was, like, a total babe', 'but you hate joey'],
 ['you looked beautiful last night, you know .', 'so did you'],
 ['let go 

In [51]:
voc.num_words

8549

In [52]:
len(pairs)

50971

### Preparing the data

In [53]:
def indexesFromSentence(voc, sentence):
    return [voc.word2index[word] for word in sentence.split(' ')] + [EOS_token]

In [54]:
pairs[1][0]

'you have my word . as a gentleman'

In [55]:
len([word for word in pairs[1][0].split(' ')])

8

In [56]:
# Test the function
indexesFromSentence(voc,pairs[1][0])

[7, 8, 9, 10, 4, 11, 12, 13, 2]

In [57]:
len(indexesFromSentence(voc,pairs[1][0]))

9

In [58]:
voc.word2count.items()



In [59]:
voc.word2index

{'there': 3,
 '.': 4,
 'where': 5,
 '?': 6,
 'you': 7,
 'have': 8,
 'my': 9,
 'word': 10,
 'as': 11,
 'a': 12,
 'gentleman': 13,
 're': 14,
 'sweet': 15,
 'hi': 16,
 'looks': 17,
 'like': 18,
 'things': 19,
 'worked': 20,
 'out': 21,
 'tonight,': 22,
 'huh': 23,
 'know': 24,
 'i': 25,
 'believe': 26,
 'we': 27,
 'share': 28,
 'an': 29,
 'art': 30,
 'fun': 31,
 'tonight': 32,
 'tons': 33,
 'well,': 34,
 'no': 35,
 'then': 36,
 'that': 37,
 's': 38,
 'all': 39,
 'had': 40,
 'to': 41,
 'say': 42,
 'but': 43,
 'always': 44,
 'been': 45,
 'this': 46,
 'selfish': 47,
 'do': 48,
 'listen': 49,
 'crap': 50,
 'what': 51,
 'good': 52,
 'stuff': 53,
 'the': 54,
 'real': 55,
 'fear': 56,
 'of': 57,
 'wearing': 58,
 'wow': 59,
 'let': 60,
 'go': 61,
 'she': 62,
 'okay': 63,
 'hope': 64,
 'so': 65,
 'they': 66,
 '!': 67,
 'not': 68,
 'did': 69,
 'change': 70,
 'your': 71,
 'hair': 72,
 'might': 73,
 'wanna': 74,
 'think': 75,
 'about': 76,
 'it': 77,
 'who': 78,
 'joey': 79,
 'great': 80,
 'would': 

In [60]:
# Define some samples for testing

inp = []
out = []

for pair in pairs[:10]:
    inp.append(pair[0])
    out.append(pair[1])
    
print(inp)
print(len(inp))

indexes = [indexesFromSentence(voc, sentence) for sentence in inp]
indexes

['there .', 'you have my word . as a gentleman', 'hi .', 'have fun tonight ?', 'well, no . . .', 'then that s all you had to say .', 'but', 'do you listen to this crap ?', 'what good stuff ?', 'wow']
10


[[3, 4, 2],
 [7, 8, 9, 10, 4, 11, 12, 13, 2],
 [16, 4, 2],
 [8, 31, 32, 6, 2],
 [34, 35, 4, 4, 4, 2],
 [36, 37, 38, 39, 7, 40, 41, 42, 4, 2],
 [43, 2],
 [48, 7, 49, 41, 46, 50, 6, 2],
 [51, 52, 53, 6, 2],
 [59, 2]]

In [61]:
def zeroPadding(l, fillvalue=PAD_token):
    return list(itertools.zip_longest(*l, fillvalue=fillvalue))

In [62]:
a=[1,2,3,'a']
list(itertools.zip_longest(a, fillvalue='qq'))

[(1,), (2,), (3,), ('a',)]

In [63]:
b=[[3, 4, 2],
 [7, 8, 9, 10, 4, 11, 12, 13, 2],
 [16, 4, 2],
 [8, 31, 32, 6, 2],
 [34, 35, 4, 4, 4, 2],
 [36, 37, 38, 39, 7, 40, 41, 42, 4, 2],
 [43, 2],
 [48, 7, 49, 41, 46, 50, 6, 2],
 [51, 52, 53, 6, 2],
 [59, 2]]

In [64]:
list(itertools.zip_longest(*b, fillvalue='qq'))

[(3, 7, 16, 8, 34, 36, 43, 48, 51, 59),
 (4, 8, 4, 31, 35, 37, 2, 7, 52, 2),
 (2, 9, 2, 32, 4, 38, 'qq', 49, 53, 'qq'),
 ('qq', 10, 'qq', 6, 4, 39, 'qq', 41, 6, 'qq'),
 ('qq', 4, 'qq', 2, 4, 7, 'qq', 46, 2, 'qq'),
 ('qq', 11, 'qq', 'qq', 2, 40, 'qq', 50, 'qq', 'qq'),
 ('qq', 12, 'qq', 'qq', 'qq', 41, 'qq', 6, 'qq', 'qq'),
 ('qq', 13, 'qq', 'qq', 'qq', 42, 'qq', 2, 'qq', 'qq'),
 ('qq', 2, 'qq', 'qq', 'qq', 4, 'qq', 'qq', 'qq', 'qq'),
 ('qq', 'qq', 'qq', 'qq', 'qq', 2, 'qq', 'qq', 'qq', 'qq')]

In [65]:
b

[[3, 4, 2],
 [7, 8, 9, 10, 4, 11, 12, 13, 2],
 [16, 4, 2],
 [8, 31, 32, 6, 2],
 [34, 35, 4, 4, 4, 2],
 [36, 37, 38, 39, 7, 40, 41, 42, 4, 2],
 [43, 2],
 [48, 7, 49, 41, 46, 50, 6, 2],
 [51, 52, 53, 6, 2],
 [59, 2]]

In [66]:
leng = [len(ind) for ind in indexes]
max(leng)

10

In [67]:
leng

[3, 9, 3, 5, 6, 10, 2, 8, 5, 2]

In [68]:
indexes

[[3, 4, 2],
 [7, 8, 9, 10, 4, 11, 12, 13, 2],
 [16, 4, 2],
 [8, 31, 32, 6, 2],
 [34, 35, 4, 4, 4, 2],
 [36, 37, 38, 39, 7, 40, 41, 42, 4, 2],
 [43, 2],
 [48, 7, 49, 41, 46, 50, 6, 2],
 [51, 52, 53, 6, 2],
 [59, 2]]

In [69]:
#Test the function
test_result = zeroPadding(indexes)
print(len(test_result)) # The max length is now the number of rows or the batch size
test_result

10


[(3, 7, 16, 8, 34, 36, 43, 48, 51, 59),
 (4, 8, 4, 31, 35, 37, 2, 7, 52, 2),
 (2, 9, 2, 32, 4, 38, 0, 49, 53, 0),
 (0, 10, 0, 6, 4, 39, 0, 41, 6, 0),
 (0, 4, 0, 2, 4, 7, 0, 46, 2, 0),
 (0, 11, 0, 0, 2, 40, 0, 50, 0, 0),
 (0, 12, 0, 0, 0, 41, 0, 6, 0, 0),
 (0, 13, 0, 0, 0, 42, 0, 2, 0, 0),
 (0, 2, 0, 0, 0, 4, 0, 0, 0, 0),
 (0, 0, 0, 0, 0, 2, 0, 0, 0, 0)]

In [70]:
def binaryMatrix(l, value=PAD_token):
    m = []
    for i, seq in enumerate(l):
        m.append([])
        for token in seq:
            if token == PAD_token:
                m[i].append(0)
            else:
                m[i].append(1)
    return m

In [71]:
binary_result = binaryMatrix(test_result)
binary_result

[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 [1, 1, 1, 1, 1, 1, 0, 1, 1, 0],
 [0, 1, 0, 1, 1, 1, 0, 1, 1, 0],
 [0, 1, 0, 1, 1, 1, 0, 1, 1, 0],
 [0, 1, 0, 0, 1, 1, 0, 1, 0, 0],
 [0, 1, 0, 0, 0, 1, 0, 1, 0, 0],
 [0, 1, 0, 0, 0, 1, 0, 1, 0, 0],
 [0, 1, 0, 0, 0, 1, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]]

In [72]:
test_result

[(3, 7, 16, 8, 34, 36, 43, 48, 51, 59),
 (4, 8, 4, 31, 35, 37, 2, 7, 52, 2),
 (2, 9, 2, 32, 4, 38, 0, 49, 53, 0),
 (0, 10, 0, 6, 4, 39, 0, 41, 6, 0),
 (0, 4, 0, 2, 4, 7, 0, 46, 2, 0),
 (0, 11, 0, 0, 2, 40, 0, 50, 0, 0),
 (0, 12, 0, 0, 0, 41, 0, 6, 0, 0),
 (0, 13, 0, 0, 0, 42, 0, 2, 0, 0),
 (0, 2, 0, 0, 0, 4, 0, 0, 0, 0),
 (0, 0, 0, 0, 0, 2, 0, 0, 0, 0)]

In [73]:
# Returns padded input sequence tensor as well as a tensor of lengths for each of the sequences in the batch
def inputVar(l, voc):
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    padVar = torch.LongTensor(padList)
    return padVar, lengths

In [74]:
# Returns padded target sequence tensor, padding mask, and max target length
def outputVar(l, voc):
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    max_target_len = max([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    mask = binaryMatrix(padList)
    mask = torch.BoolTensor(mask)
    padVar = torch.LongTensor(padList)
    return padVar, mask, max_target_len


In [75]:
def batch2TrainData(voc, pair_batch):
    pair_batch.sort(key=lambda x: len(x[0].split(" ")), reverse=True)
    input_batch, output_batch = [], []
    for pair in pair_batch:
        input_batch.append(pair[0])
        output_batch.append(pair[1])
    inp, lengths = inputVar(input_batch, voc)
    output, mask, max_target_len = outputVar(output_batch, voc)
    return inp, lengths, output, mask, max_target_len

In [76]:
# Example for validation
small_batch_size = 5
batches = batch2TrainData(voc, [random.choice(pairs) for _ in range(small_batch_size)])
input_variable, lengths, target_variable, mask, max_target_len = batches

print("input_variable:", input_variable)
print("lengths:", lengths)
print("target_variable:", target_variable)
print("mask:", mask)
print("max_target_len:", max_target_len)

input_variable: tensor([[ 825,   25,  126,    7, 1456],
        [  25,  295,   71, 1392,   77],
        [ 121,    7, 2931,   37,    6],
        [  41,  680, 1350,  183,    2],
        [ 377,  212,  185,    6,    0],
        [  54,    4,    4,    2,    0],
        [2678,    2,    2,    0,    0],
        [   4,    0,    0,    0,    0],
        [   2,    0,    0,    0,    0]])
lengths: tensor([9, 7, 7, 6, 4])
target_variable: tensor([[ 115,    9,   51,  227, 1456],
        [  12,  751,   38,    4,   77],
        [1538,  120,   37,    2,    4],
        [ 103,   27,  100,    0,    2],
        [ 142, 7025,    6,    0,    0],
        [  41,   46,    2,    0,    0],
        [  61,  693,    0,    0,    0],
        [   4,    6,    0,    0,    0],
        [   2,    2,    0,    0,    0]])
mask: tensor([[ True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True],
        [ True,  True,  True, False,  True],
        [ True,  True,  Tr