# ChatBot:Data Loading And Preprocessing

In [20]:
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import csv
import random
import re
import os
import unicodedata
import codecs
import itertools

In [21]:
CUDA = torch.cuda.is_available()
device = torch.device("cuda" if CUDA else "cpu")

# Data Preprocessing

In [22]:
corpus_name = "cornell movie-dialogs corpus"
corpus = os.path.join("data", corpus_name)
lines_filepath = os.path.join(corpus,"movie_lines.txt")
conv_filepath = os.path.join(corpus,"movie_conversations.txt")

In [23]:
# visualizing some lines
with open(lines_filepath,'r') as file:
    lines = file.readlines()
for line in lines[:10]:
    print(line.strip())

L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!
L1044 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ They do to!
L985 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I hope so.
L984 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ She okay?
L925 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Let's go.
L924 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ Wow
L872 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Okay -- you're gonna need to learn how to lie.
L871 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ No
L870 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I'm kidding.  You know how sometimes you just become this "persona"?  And you don't know how to quit?
L869 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Like my fear of wearing pastels?


In [24]:
# splits each line of the file into a dictionary of fields(lineID,characterID,movieID,character,text)
line_fields = ["lineID","characterID","movieID","character","text"]
lines = {}

with open(lines_filepath, 'r', encoding='iso-8859-1') as f:
    for line in f:
        values = line.split(" +++$+++ ")
        # Extract fields
        lineObj = {}
        for i, field in enumerate(line_fields):
            lineObj[field] = values[i]
        lines[lineObj['lineID']] = lineObj

In [26]:
# Grouping fields of lines from the above loaded lines into conversation based on "movie_conversations.txt"

conv_fields = ["character1ID", "character2ID", "movieID", "utteranceIDs"]
conversations = []

with open(conv_filepath, 'r', encoding='iso-8859-1') as f:
    for line in f:
        values = line.split(" +++$+++ ")
        # Extract fields
        convObj = {}
        for i, field in enumerate(conv_fields):
            convObj[field] = values[i]
        lineIds = eval(convObj["utteranceIDs"])
        # Reassemble lines
        convObj["lines"] = []
        for lineId in lineIds:
            convObj["lines"].append(lines[lineId])
        conversations.append(convObj)

In [28]:
# Extracts pairs of sentences from conversations
qa_pairs = []
for conversation in conversations:
    # Iterate over all the lines of the conversation
    for i in range(len(conversation["lines"]) - 1):  # We ignore the last line (no answer for it)
        inputLine = conversation["lines"][i]["text"].strip()
        targetLine = conversation["lines"][i+1]["text"].strip()
        # Filter wrong samples (if one of the lists is empty)
        if inputLine and targetLine:
            qa_pairs.append([inputLine, targetLine])

In [30]:
# Define path to new file
datafile = os.path.join(corpus, "formatted_movie_lines.txt")

delimiter = '\t'
# Unescape the delimiter
delimiter = str(codecs.decode(delimiter, "unicode_escape"))

# Writing the conversational pairs into new csv file
print("\nWriting newly formatted file...")
with open(datafile, 'w', encoding='utf-8') as outputfile:
    writer = csv.writer(outputfile, delimiter=delimiter)
    for pair in qa_pairs:
        writer.writerow(pair)


Writing newly formatted file...


In [31]:
# visualizing some lines
with open(datafile,'rb') as file:
    lines = file.readlines()
for line in lines[:10]:
    print(line.strip())

b"Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\tWell, I thought we'd start with pronunciation, if that's okay with you."
b"Well, I thought we'd start with pronunciation, if that's okay with you.\tNot the hacking and gagging and spitting part.  Please."
b"Not the hacking and gagging and spitting part.  Please.\tOkay... then how 'bout we try out some French cuisine.  Saturday?  Night?"
b"You're asking me out.  That's so cute. What's your name again?\tForget it."
b"No, no, it's my fault -- we didn't have a proper introduction ---\tCameron."
b"Cameron.\tThe thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does."
b"The thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\tSeems like she could get a date easy enough..."
b'Why?\tUnsolved mystery.  She used to be really po

In [32]:
# Default word tokens
PAD_token = 0  # Used for padding short sentences
SOS_token = 1  # Start-of-sentence token
EOS_token = 2  # End-of-sentence token

class Vocabulary:
    def __init__(self, name):
        self.name = name
        self.trimmed = False
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3  # Counting SOS, EOS, PAD

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1
        else:
            self.word2count[word] += 1

    # Removing words below a certain count threshold
    def trim(self, min_count):
        if self.trimmed:
            return
        self.trimmed = True

        keep_words = []

        for k, v in self.word2count.items():
            if v >= min_count:
                keep_words.append(k)

        print('keep_words {} / {} = {:.4f}'.format(
            len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index)
        ))

        # Reinitializing dictionaries
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3 # Counting default tokens

        for word in keep_words:
            self.addWord(word)

In [33]:
# Turn a Unicode string to plain ASCII
def unicodeToAscii(s):
    return ''.join(c for c in unicodedata.normalize('NFD', s) if unicodedata.category(c) != 'Mn')

In [34]:
# Testing the function - eliminates special characters
unicodeToAscii("cédillea")

'cedillea'

In [35]:
# Lowercase, trim, and remove non-letter characters
def normalizeString(s):
    
    # Unicode string to plain ASCII
    s = unicodeToAscii(s.lower().strip())
    
    # Replacing any .!? by a whitespace plus the character
    # ' \1' means the first bracketed group
    # r is not to consider ' \1' as an individual character
    # r in r" \1" is to esccape the backslash
    s = re.sub(r"([.!?])", r" \1", s)
    
    # Removing any character that is not a sequence of lower or upper case letters
    # + means one or more
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    
    # Removing a sequence of whitespace characters
    s = re.sub(r"\s+", r" ", s).strip()
    
    return s

In [36]:
# Testing the function
normalizeString("aaa1103!?.     aaa'a aaa?")

'aaa ! ? . aaa a aaa ?'

In [37]:
datafile = os.path.join(corpus, "formatted_movie_lines.txt")

# Reading the file and splitting into lines
print("Reading and processing file....Please Wait")
lines = open(datafile,encoding = 'utf-8').read().strip().split('\n')

# Splitting every line into pairs and normalizing them
pairs = [[normalizeString(s) for s in pair.split('\t')] for pair in lines]

print("Done Reading!!!")
voc = Vocabulary(corpus)

Reading and processing file....Please Wait
Done Reading!!!


In [39]:
MAX_LENGTH = 10  # Maximum sentence length to consider

# Returns True iff both sentences in a pair 'p' are under the MAX_LENGTH threshold
def filterPair(p):
    # Input sequences need to preserve the last word for EOS token
    if p:
        return len(p[0].split(' ')) < MAX_LENGTH and len(p[1].split(' ')) < MAX_LENGTH
    else:
        return False

# Filter pairs using filterPair condition
def filterPairs(pairs):
    pairs = [pair for pair in pairs if pair != ['']]
    return [pair for pair in pairs if filterPair(pair)]

In [40]:
print("There are {} pairs/conversations in the dataset".format(len(pairs)))
pairs = filterPairs(pairs)
print("After filtering there are {} pairs/conversations".format(len(pairs)))

There are 442563 pairs/conversations in the dataset
After filtering there are 64271 pairs/conversations


In [42]:
for pair in pairs:
        voc.addSentence(pair[0])
        voc.addSentence(pair[1])

print("Counted words:", voc.num_words)
for pair in pairs[:10]:
    print(pair)

Counted words: 18008
['there .', 'where ?']
['you have my word . as a gentleman', 'you re sweet .']
['hi .', 'looks like things worked out tonight huh ?']
['you know chastity ?', 'i believe we share an art instructor']
['have fun tonight ?', 'tons']
['well no . . .', 'then that s all you had to say .']
['then that s all you had to say .', 'but']
['but', 'you always been this selfish ?']
['do you listen to this crap ?', 'what crap ?']
['what good stuff ?', 'the real you .']


In [43]:
MIN_COUNT = 3    # Minimum word count threshold for trimming

def trimRareWords(voc, pairs, MIN_COUNT):
    # Trim words used under the MIN_COUNT from the voc
    voc.trim(MIN_COUNT)
    # Filter out pairs with trimmed words
    keep_pairs = []
    for pair in pairs:
        input_sentence = pair[0]
        output_sentence = pair[1]
        keep_input = True
        keep_output = True
        # Check input sentence
        for word in input_sentence.split(' '):
            if word not in voc.word2index:
                keep_input = False
                break
        # Check output sentence
        for word in output_sentence.split(' '):
            if word not in voc.word2index:
                keep_output = False
                break

        # Only keep pairs that do not contain trimmed word(s) in their input or output sentence
        if keep_input and keep_output:
            keep_pairs.append(pair)

    print("Trimmed from {} pairs to {}, {:.4f} of total".format(len(pairs), len(keep_pairs), len(keep_pairs) / len(pairs)))
    return keep_pairs


# Trim voc and pairs
pairs = trimRareWords(voc, pairs, MIN_COUNT)

keep_words 7823 / 18005 = 0.4345
Trimmed from 64271 pairs to 53165, 0.8272 of total


In [44]:
def indexesFromSentence(voc, sentence):
    return [voc.word2index[word] for word in sentence.split(' ')] + [EOS_token]

In [45]:
indexesFromSentence(voc,pairs[1][0])

[7, 8, 9, 10, 4, 11, 12, 13, 2]

In [46]:
def zeroPadding(l, fillvalue=PAD_token):
    return list(itertools.zip_longest(*l, fillvalue=fillvalue))

In [47]:
def binaryMatrix(l, value=PAD_token):
    m = []
    for i, seq in enumerate(l):
        m.append([])
        for token in seq:
            if token == PAD_token:
                m[i].append(0)
            else:
                m[i].append(1)
    return m

In [48]:
# Returns padded input sequence tensor and lengths
def inputVar(l, voc):
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    padVar = torch.LongTensor(padList)
    return padVar, lengths

# Returns padded target sequence tensor, padding mask, and max target length
def outputVar(l, voc):
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    max_target_len = max([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    mask = binaryMatrix(padList)
    mask = torch.ByteTensor(mask)
    padVar = torch.LongTensor(padList)
    return padVar, mask, max_target_len

In [49]:
# Returns all items for a given batch of pairs
def batch2TrainData(voc, pair_batch):
    pair_batch.sort(key=lambda x: len(x[0].split(" ")), reverse=True)
    input_batch, output_batch = [], []
    for pair in pair_batch:
        input_batch.append(pair[0])
        output_batch.append(pair[1])
    inp, lengths = inputVar(input_batch, voc)
    output, mask, max_target_len = outputVar(output_batch, voc)
    return inp, lengths, output, mask, max_target_len

In [50]:
# Example for validation
small_batch_size = 5
batches = batch2TrainData(voc, [random.choice(pairs) for _ in range(small_batch_size)])
input_variable, lengths, target_variable, mask, max_target_len = batches

print("input_variable:", input_variable)
print("lengths:", lengths)
print("target_variable:", target_variable)
print("mask:", mask)
print("max_target_len:", max_target_len)

input_variable: tensor([[  94,   65,   16,  157, 4514],
        [ 117,  702, 1119,  266,    4],
        [   7,   96,    4,    6,    2],
        [ 697,  214,    2,    2,    0],
        [  83,    6,    0,    0,    0],
        [   6,    2,    0,    0,    0],
        [   2,    0,    0,    0,    0]])
lengths: tensor([7, 6, 4, 4, 3])
target_variable: tensor([[  25,   65,   16,  239,   77],
        [ 697,  517, 4607,  488,   37],
        [   7,   96,    4,  198, 4514],
        [2182,  214,    5,   76,    6],
        [   4,    4,    7,  180,    2],
        [   2,    2,   44,    4,    0],
        [   0,    0,    6,    2,    0],
        [   0,    0,    2,    0,    0]])
mask: tensor([[1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 0],
        [0, 0, 1, 1, 0],
        [0, 0, 1, 0, 0]], dtype=torch.uint8)
max_target_len: 8
