In [2]:
import torch
import torch.nn as nn ##neural network
from torch import optim #optimizers
import torch.nn.functional as F
import csv
import random
import re
import os
import unicodedata
import codecs
import itertools

In [3]:
USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu") # if true  GPU will be used instead of CPU

In [4]:
# creating path for files
lines_filePath = os.path.join("dataset","movie_lines.txt")
conv_filePath = os.path.join("dataset", "movie_conversations.txt")


In [5]:
#print some lines

def printLines(file):
    with open(file,'rb') as file:
        lines=file.readlines()
    for line in lines[:8]: # we are only print 0-7 lines just to visulaise the data
        print(line.strip())
        
printLines(lines_filePath)
printLines(conv_filePath)

b'L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!'
b'L1044 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ They do to!'
b'L985 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I hope so.'
b'L984 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ She okay?'
b"L925 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Let's go."
b'L924 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ Wow'
b"L872 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Okay -- you're gonna need to learn how to lie."
b'L871 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ No'
b"u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L194', 'L195', 'L196', 'L197']"
b"u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L198', 'L199']"
b"u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L200', 'L201', 'L202', 'L203']"
b"u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L204', 'L205', 'L206']"
b"u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L207', 'L208']"
b"u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L271', 'L272', 'L273', 'L274', 'L275']"
b"u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L276', 'L277']"
b"u0 +++$+++ u2 +++$+++ m

In [6]:
# Splits each line of the file into a dictionary of fields
def loadLines(file):
    line_fields = ["lineId", "characterId", "movieId","character","text"]
    lines={}
    #if encoding is not done then below spliting will throw error
    with open(file,'r', encoding='iso-8859-1') as f: 
        for line in f:
            values = line.split(" +++$+++ ")
            # Extract fields
            lineObj = {}
            for i, field in enumerate(line_fields): # enumerate will iterate bubt also keep track of index( i is index here, and filed will be data at that index)
                    lineObj[field] = values[i]
            lines[lineObj['lineId']] = lineObj
    return lines 

In [7]:
# now processing conversation file ehich will tell who is speaking to whom

def loadConv(file):
    conv_filed=["charachterId", "character2Id","movieId","utteranceIds"]
    conversations = []
    with open(file, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(" +++$+++ ")
            convObj = {}
            for i, field in enumerate(conv_filed):
                    convObj[field] = values[i]
        
            # Convert string to list as utterance is list stored as string e.g "['L598485', 'L598486', ...]")
        
            lineIds = eval(convObj["utteranceIds"])
            
            convObj["lines"]=[]
            for lineId in lineIds:
                convObj["lines"].append(lines[lineId])
            conversations.append(convObj)
            
    return conversations

In [11]:
loadConv(conv_filePath)

[{'charachterId': 'u0',
  'character2Id': 'u2',
  'movieId': 'm0',
  'utteranceIds': "['L194', 'L195', 'L196', 'L197']\n",
  'lines': [{'lineId': 'L194',
    'characterId': 'u0',
    'movieId': 'm0',
    'character': 'BIANCA',
    'text': 'Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\n'},
   {'lineId': 'L195',
    'characterId': 'u2',
    'movieId': 'm0',
    'character': 'CAMERON',
    'text': "Well, I thought we'd start with pronunciation, if that's okay with you.\n"},
   {'lineId': 'L196',
    'characterId': 'u0',
    'movieId': 'm0',
    'character': 'BIANCA',
    'text': 'Not the hacking and gagging and spitting part.  Please.\n'},
   {'lineId': 'L197',
    'characterId': 'u2',
    'movieId': 'm0',
    'character': 'CAMERON',
    'text': "Okay... then how 'bout we try out some French cuisine.  Saturday?  Night?\n"}]},
 {'charachterId': 'u0',
  'character2Id': 'u2',
  'movieId': 'm0',
  'uttera

In [9]:
def extractSentencePairs(conversations):
    qa_pairs = []
    for conversation in conversations:
        for i in range(len(conversation["lines"]) - 1):  # We ignore the last line (no answer for it)
            inputLine = conversation["lines"][i]["text"].strip()
            targetLine = conversation["lines"][i+1]["text"].strip()
             #making sure none of them is empty
        if inputLine and targetLine:
            qa_pairs.append([inputLine, targetLine])
            
    return qa_pairs


In [10]:
#DRIVER CODE


lines=loadLines(lines_filePath)
conversations = loadConv(conv_filePath)
qa_pairs=extractSentencePairs(conversations)   

#wrinting formatted data to new file

## Define path to new file
datafile = os.path.join("dataset", "formatted_movie_lines.txt")
delimiter = '\t'
delimiter = str(codecs.decode(delimiter, "unicode_escape"))
# Write new csv file
print("\nWriting newly formatted file...")
with open(datafile, 'w', encoding='utf-8') as outputfile:
    writer = csv.writer(outputfile, delimiter=delimiter, lineterminator='\n')
    for pair in extractSentencePairs(conversations):
        writer.writerow(pair)
print("\nSample lines from file:")
printLines(datafile)


Writing newly formatted file...

Sample lines from file:
b"Not the hacking and gagging and spitting part.  Please.\tOkay... then how 'bout we try out some French cuisine.  Saturday?  Night?"
b"You're asking me out.  That's so cute. What's your name again?\tForget it."
b"The thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\tSeems like she could get a date easy enough..."
b"Unsolved mystery.  She used to be really popular when she started high school, then it was just like she got sick of it or something.\tThat's a shame."
b'Gosh, if only we could find Kat a boyfriend...\tLet me see what I can do.'
b"That's because it's such a nice one.\tForget French."
b"How is our little Find the Wench A Date plan progressing?\tWell, there's someone I think might be --"
b'There.\tWhere?'


In [32]:
# processing the words

### Default word tokens
PAD_token = 0  # Used for padding short sentences
SOS_token = 1  # Start-of-sentence token
EOS_token = 2  # End-of-sentence token

class Voc:
    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.trimmed = False
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3  # Count SOS, EOS, PAD

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)
            
    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1
        else:
            self.word2count[word] += 1
            
            

    # Remove words below a certain count threshold
    def trim(self, min_count):
        if self.trimmed:
            return
        self.trimmed = True

        keep_words = []

        for k, v in self.word2count.items():
            if v >= min_count:
                keep_words.append(k)

        print('keep_words {} / {} = {:.4f}'.format(
            len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index)
        ))

        # Reinitialize dictionaries
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3 # Count default tokens

        for word in keep_words:
            self.addWord(word)
            

In [33]:

# Turn a Unicode string to plain ASCII,
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

# Lowercase, trim, and remove non-letter characters
def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    s = re.sub(r"\s+", r" ", s).strip()
    return s



In [34]:
# Read query/response pairs and return a voc object
def readVocs(datafile):
    print("Reading lines...")
    # Read the file and split into lines
    lines = open(datafile, encoding='utf-8').\
        read().strip().split('\n')
    # Split every line into pairs and normalize
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]
    voc = Voc("dataset")
    print ("done reading")
    return voc, pairs

voc, pairs = readVocs(datafile)

Reading lines...
done reading


In [35]:
# filtering sentences 

MAX_LENGTH = 10  # Maximum sentence length to consider
def filterPair(p):
    return len(p[0].split(' ')) < MAX_LENGTH and len(p[1].split(' ')) < MAX_LENGTH

def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]

pairs= [pair for pair in pairs if len(pair)>1]
pairs= filterPairs(pairs)


In [36]:
#Loop through each pair and add Question and reply to vocab

for pair in pairs:
    voc.addSentence(pair[0])
    voc.addSentence(pair[1])
print ("counted words ", voc.num_words)

#print some to check
for pair in pairs[:8]:
    print(pair)

counted words  12198
['there .', 'where ?']
['you have my word . as a gentleman', 'you re sweet .']
['hi .', 'looks like things worked out tonight huh ?']
['you know chastity ?', 'i believe we share an art instructor']
['have fun tonight ?', 'tons']
['but', 'you always been this selfish ?']
['the real you .', 'like my fear of wearing pastels ?']
['wow', 'let s go .']


In [37]:
# removing rarely used words
MIN_COUNT = 3    # Minimum word count threshold for trimming
def trimRareWords(voc, pairs, MIN_COUNT):
    voc.trim(MIN_COUNT) # removes words below min count
    # remove pairs that includes these removed pairs
    keep_pairs=[]
    for pair in pairs:
        input_sentence = pair[0]
        output_sentence = pair[1]
        keep_input = True
        keep_output = True
        
        # removed word will not be at word@index dictionary 
        
        # Check input sentence
        for word in input_sentence.split(' '):
            if word not in voc.word2index:
                keep_input = False
                break
        # Check output sentence
        for word in output_sentence.split(' '):
            if word not in voc.word2index:
                keep_output = False
                break


        # Only keep pairs that do not contain trimmed word(s) in their input or output sentence
        if keep_input and keep_output:
            keep_pairs.append(pair)

    print("Trimmed from {} pairs to {}, {:.4f} of total".format(len(pairs), len(keep_pairs), len(keep_pairs) / len(pairs)))
    return keep_pairs



# Trim voc and pairs
pairs = trimRareWords(voc, pairs, MIN_COUNT)


keep_words 4015 / 12195 = 0.3292
Trimmed from 24337 pairs to 16872, 0.6933 of total
