In [24]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import torch
from torch.jit import script, trace
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import csv
import random
import re
import os
import unicodedata
import codecs
from io import open
import itertools
import math
import zipfile
import pandas as pd


USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")

In [10]:
!wget https://dl.dropboxusercontent.com/s/2n5lwz8igwquo1h/movie.zip?dl=0 -O /content/data.zip

--2021-03-11 06:41:27--  https://dl.dropboxusercontent.com/s/2n5lwz8igwquo1h/movie.zip?dl=0
Resolving dl.dropboxusercontent.com (dl.dropboxusercontent.com)... 162.125.3.15, 2620:100:6018:15::a27d:30f
Connecting to dl.dropboxusercontent.com (dl.dropboxusercontent.com)|162.125.3.15|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 9247312 (8.8M) [application/zip]
Saving to: ‘/content/data.zip’


2021-03-11 06:41:28 (52.1 MB/s) - ‘/content/data.zip’ saved [9247312/9247312]



In [11]:
train_zip = '/content/data.zip'
train_dir = '/content/data'


with zipfile.ZipFile(train_zip, 'r') as zip_ref:
    zip_ref.extractall(train_dir)

In [31]:
def printLines(file, n=10):
    with open(file, 'rb') as datafile:
        lines = datafile.readlines()
    for line in lines[:n]:
        print(line)

In [12]:
with open('/content/data/movie_lines.txt', 'rb') as corpus:
  lines = corpus.readlines()
  for line in lines[:10]:
    print(line)

b'L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!\n'
b'L1044 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ They do to!\n'
b'L985 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I hope so.\n'
b'L984 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ She okay?\n'
b"L925 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Let's go.\n"
b'L924 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ Wow\n'
b"L872 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Okay -- you're gonna need to learn how to lie.\n"
b'L871 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ No\n'
b'L870 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I\'m kidding.  You know how sometimes you just become this "persona"?  And you don\'t know how to quit?\n'
b'L869 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Like my fear of wearing pastels?\n'


In [13]:
# Split each line into a dictionary of fields (lineID, characterID, movieID, character, text)
def loadLines(path, fields):
  lines = {}
  with open(path, 'r', encoding='iso-8859-1') as f:
    for line in f:
      values = line.split(' +++$+++ ')
      lineObj = {}
      for i, field in enumerate(fields):
        lineObj[field] = values[i]
      lines[lineObj['lineID']] = lineObj
  # returns a dictionary containing the dictionaries representing each line
  # {'L871': {'lineID': ...}, 'L987': {}}
  return lines

In [17]:
# movie_conversations.txt groups lines into conversations, using the ID of these lines 
# I.e (Conversation 1: [L453, L465, L436])
# Group lines into their respective conversations
def loadConversations(path, lines, fields):
  conversations = []
  with open(path, 'r', encoding='iso-8859-1') as f:
    for line in f:
      values = line.split(' +++$+++ ')
      convObj = {}
      for i, field in enumerate(fields):
        convObj[field] = values[i]
      # Use this pattern to find all strings that contain this pattern of 'L<Numbers>'
      utterance_id_pattern = re.compile('L[0-9]+')
      lineIds = utterance_id_pattern.findall(convObj['utteranceIDs'])
      convObj['lines'] = []
      for lineId in lineIds:
        convObj['lines'].append(lines[lineId])
      conversations.append(convObj)
  return conversations

In [18]:
# Create q_a pairs using the conversation list. 
def extractSentencePairs(conversations):
  qa_pairs = []
  for conversation in conversations:
    for i in range(len(conversation['lines']) - 1):
      inputLine = conversation['lines'][i]['text'].strip()
      targetLine = conversation['lines'][i+1]['text'].strip()
      # Ensures that the lines are not empty
      if inputLine and targetLine:
        qa_pairs.append([inputLine, targetLine])
    
  return qa_pairs

In [19]:
dpath = '/content/data/'
lines = {}
conversations = {}
MOVIE_LINES_FIELDS = ["lineID", "characterID", "movieID", "character", "text"]
MOVIE_CONVERSATIONS_FIELDS = ["character1ID", "character2ID", "movieID", "utteranceIDs"]

lines = loadLines(dpath + 'movie_lines.txt', MOVIE_LINES_FIELDS)
conversations = loadConversations(dpath+'movie_conversations.txt', lines, MOVIE_CONVERSATIONS_FIELDS)

In [30]:
delimiter = '\t'
delimiter = str(codecs.decode(delimiter, "unicode_escape"))

with open(dpath + 'formatted_data.txt', 'w', encoding='utf-8') as f:
  writer = csv.writer(f, delimiter=delimiter, lineterminator='\n')
  for pair in extractSentencePairs(conversations):
    writer.writerow(pair)

In [32]:
printLines(dpath+'formatted_data.txt')

b"Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\tWell, I thought we'd start with pronunciation, if that's okay with you.\n"
b"Well, I thought we'd start with pronunciation, if that's okay with you.\tNot the hacking and gagging and spitting part.  Please.\n"
b"Not the hacking and gagging and spitting part.  Please.\tOkay... then how 'bout we try out some French cuisine.  Saturday?  Night?\n"
b"You're asking me out.  That's so cute. What's your name again?\tForget it.\n"
b"No, no, it's my fault -- we didn't have a proper introduction ---\tCameron.\n"
b"Cameron.\tThe thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\n"
b"The thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\tSeems like she could get a date easy enough...\n"
b'Why?\tUnsolved mystery.  She used t

In [42]:
PAD_token = 0
SOS_token = 1
EOS_token = 2

class Voc:
  def __init__(self, name):
    self.name = name
    self.trimmed = False
    self.word2index = {}
    self.word2count = {}
    self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
    self.num_words = 3

  def addSentence(self, sentence):
    for word in sentence.split(' '):
      self.addWord(word)
  
  def addWord(self, word):
    if word not in self.word2index:
      self.word2index[word] = self.num_words
      self.word2count[word] = 1
      self.index2word[self.num_words] = word
      self.num_words += 1
    else:
      self.word2count[word] += 1
  
  def trim(self, min_count):
    if self.trimmed:
      return
    self.trimmed = True
    keep_words = []

    for word, occurance in self.word2count.items():
      if occurance >= min_count:
        keep_words.append(word)
    
    print(f'Keep_words {len(keep_words)}, {len(keep_words)/len(self.word2index)}%')

    # Reset all the parameters. Will have to do addWord again to populate the dictionaries
    self.word2index = []
    self.word2count = {}
    self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
    self.num_words = 3

    for word in keep_words:
      self.addWord(word)

In [51]:
# Maximum sentence length to consider. Any longer sentences will not be considered
MAX_LENGTH = 10

def unicodetoAscii(s):
  return ''.join(
    c for c in unicodedata.normalize('NFD', s)
    if unicodedata.category(c) != 'Mn'
  )

# Lowercase, trim and remove non-letter characters
def normalizeString(s):
  s = unicodetoAscii(s.lower().strip())
  s = re.sub(r"([.!?])", r" \1", s)
  s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
  s = re.sub(r"\s+", r" ", s).strip()
  return s


def readVocs(datafile, corpus_name):
  lines = open(datafile, encoding='utf-8').read().strip().split('\n')
  pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]
  voc = Voc(corpus_name)
  return voc, pairs


# Returns True if both sentences in a pair are under the MAX_LENGTH threshold
def filterPair(p):
    return len(p[0].split(' ')) < MAX_LENGTH and len(p[1].split(' ')) < MAX_LENGTH


def filterPairs(pairs):
  return [pair for pair in pairs if filterPair(pair)]


def loadPrepareData(corpus, corpus_name, datafile, save_dir):
  voc, pairs = readVocs(datafile, corpus_name)
  print(f'Read {len(pairs)} sentence pairs')
  pairs = filterPairs(pairs)
  print(f'Trimmed to {len(pairs)} sentence pairs')
  for pair in pairs:
    voc.addSentence(pair[0])
    voc.addSentence(pair[1])
  print(f'{voc.num_words} in total')
  return voc, pairs


save_dir = '/content/data/save/'
corpus_name = 'movie'
voc, pairs = loadPrepareData(corpus, corpus_name, dpath + 'formatted_data.txt', save_dir)
for pair in pairs[:10]:
  print(pair)

Read 221282 sentence pairs
Trimmed to 64271 sentence pairs
18008 in total
['there .', 'where ?']
['you have my word . as a gentleman', 'you re sweet .']
['hi .', 'looks like things worked out tonight huh ?']
['you know chastity ?', 'i believe we share an art instructor']
['have fun tonight ?', 'tons']
['well no . . .', 'then that s all you had to say .']
['then that s all you had to say .', 'but']
['but', 'you always been this selfish ?']
['do you listen to this crap ?', 'what crap ?']
['what good stuff ?', 'the real you .']


In [52]:
MIN_COUNT = 3

def trimRareWords(voc, pairs, MIN_COUNT):
  # Remove all words whose occurance are below the min count
  voc.trim(MIN_COUNT)
  keep_pairs = []
  for pair in pairs:
    input_sentence = pair[0]
    output_sentence = pair[1]
    keep_input = True
    keep_output = True
    for word in input_sentence.split(' '):
      if word not in voc.word2index:
        keep_input = False
        break
      
    for word in output_sentence.split(' '):
      if word not in voc.word2index:
        keep_output = False
        break
    
    if keep_input and keep_output:
      # Only append if both sentences are accepted
      keep_pairs.append(pair)

  
  print(f'Trimmed from {len(pairs)} pairs to {len(keep_pairs)} pairs')

  return keep_pairs


pairs = trimRareWords(voc, pairs, MIN_COUNT)

Keep_words 7823, 0.43449041932796445%


TypeError: ignored