In [None]:
import torch
from torch.jit import script, trace
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import csv
import random
import re
import os
import unicodedata
import codecs
from io import open
import itertools
import math
import json
from google.colab import drive
# Mount Google Drive at /content/drive (relies on the google.colab import above).
drive.mount('/content/drive')

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")

Mounted at /content/drive


In [None]:
# Corpus name and the folder under the mounted Drive that holds its files.
corpus_name = "movie-corpus"
corpus = os.path.join("/content/drive/My Drive/", corpus_name)

def printLines(file, n=10):
    """Print the first ``n`` raw lines of a file as ``bytes`` objects.

    The file is opened in binary mode, so each line is printed in its
    ``b'...'`` representation including the trailing newline byte.

    Args:
        file: path of the file to preview.
        n: number of leading lines to print (default 10). ``None`` prints
            every line; a negative value keeps list-slice semantics
            (everything except the last ``abs(n)`` lines).
    """
    with open(file, 'rb') as datafile:
        if n is not None and n < 0:
            # Negative counts need the total line count, so read everything.
            selected = datafile.readlines()[:n]
        else:
            # Read lazily: stop after n lines instead of loading the whole
            # file into memory just to show a short preview.
            selected = list(itertools.islice(datafile, n))
    for line in selected:
        print(line)

# Preview the first raw records of the corpus' utterances.jsonl file.
printLines(os.path.join(corpus, "utterances.jsonl"))

b'{"id": "L1045", "conversation_id": "L1044", "text": "They do not!", "speaker": "u0", "meta": {"movie_id": "m0", "parsed": [{"rt": 1, "toks": [{"tok": "They", "tag": "PRP", "dep": "nsubj", "up": 1, "dn": []}, {"tok": "do", "tag": "VBP", "dep": "ROOT", "dn": [0, 2, 3]}, {"tok": "not", "tag": "RB", "dep": "neg", "up": 1, "dn": []}, {"tok": "!", "tag": ".", "dep": "punct", "up": 1, "dn": []}]}]}, "reply-to": "L1044", "timestamp": null, "vectors": []}\n'
b'{"id": "L1044", "conversation_id": "L1044", "text": "They do to!", "speaker": "u2", "meta": {"movie_id": "m0", "parsed": [{"rt": 1, "toks": [{"tok": "They", "tag": "PRP", "dep": "nsubj", "up": 1, "dn": []}, {"tok": "do", "tag": "VBP", "dep": "ROOT", "dn": [0, 2, 3]}, {"tok": "to", "tag": "TO", "dep": "dobj", "up": 1, "dn": []}, {"tok": "!", "tag": ".", "dep": "punct", "up": 1, "dn": []}]}]}, "reply-to": null, "timestamp": null, "vectors": []}\n'
b'{"id": "L985", "conversation_id": "L984", "text": "I hope so.", "speaker": "u0", "meta": {

In [None]:
def loadLinesAndConversations(fileName):
  """Parse a JSON-lines utterance file into line and conversation lookups.

  Every non-blank line of the file is decoded as one JSON object, from which
  the "id", "speaker", "text", "conversation_id" and "meta" -> "movie_id"
  fields are read.

  Args:
    fileName: path of the UTF-8 encoded JSON-lines file.

  Returns:
    A ``(lines, conversations)`` tuple:
      * ``lines`` maps each line id to a dict with the keys "LineID",
        "characterID" and "text".
      * ``conversations`` maps each conversation id to a dict with the keys
        "conversationID", "movieID" and "lines" (a list of the line dicts
        that belong to the conversation).

  Raises:
    json.JSONDecodeError: if a non-blank line is not valid JSON.
    KeyError: if a record lacks one of the fields listed above.
  """
  lines = {}
  conversations = {}
  with open(fileName, "r", encoding="utf-8") as f:
    for line in f:
      # Tolerate blank lines (e.g. an empty line at the end of the file)
      # instead of failing inside json.loads.
      if not line.strip():
        continue
      lineJson = json.loads(line)
      lineObj = {
          "LineID": lineJson["id"],
          "characterID": lineJson["speaker"],
          "text": lineJson["text"],
      }
      lines[lineObj["LineID"]] = lineObj

      convId = lineJson["conversation_id"]
      if convId not in conversations:
        conversations[convId] = {
            "conversationID": convId,
            "movieID": lineJson["meta"]["movie_id"],
            "lines": [lineObj],
        }
      else:
        # Prepend: utterances read later are placed first, so each
        # conversation's list ends up in reverse file order.
        conversations[convId]["lines"].insert(0, lineObj)
  return lines, conversations

def extractSentencePairs(conversations):
    """Collect [input, reply] text pairs from adjacent lines of every conversation.

    Args:
        conversations: mapping whose values are dicts holding a "lines" list;
            each list element is a dict with a "text" entry.

    Returns:
        A list of two-element lists ``[input_text, reply_text]`` built from
        each pair of neighbouring lines, with surrounding whitespace stripped.
        A pair is left out when either side is empty after stripping.
    """
    qa_pairs = []
    for conv in conversations.values():
        utterances = conv["lines"]
        # Walk the list as (current, next) neighbours.
        for current, following in zip(utterances, utterances[1:]):
            question = current["text"].strip()
            answer = following["text"].strip()
            if not (question and answer):
                continue
            qa_pairs.append([question, answer])
    return qa_pairs



In [None]:
# Path of the delimiter-separated sentence-pair file written by this cell.
datafile = os.path.join(corpus, "formatted_movie_lines.txt")

# Field separator used between the two sentences of a pair.
delimiter = '\t'

# Interpret any backslash escape sequences contained in the delimiter string.
delimiter = str(codecs.decode(delimiter, "unicode_escape"))

# Start from empty containers; both names are rebound to the loader's result below.
lines = {}
conversations = {}

print("\nProcessing corpus into lines and conversations...")
lines, conversations = loadLinesAndConversations(os.path.join(corpus, "utterances.jsonl"))

print("\nWriting newly formatted file...")
# One row per pair: the two sentences separated by the delimiter, rows ended with "\n".
with open(datafile, 'w', encoding='utf-8') as outputfile:
    writer = csv.writer(outputfile, delimiter=delimiter, lineterminator='\n')
    for pair in extractSentencePairs(conversations):
        writer.writerow(pair)


Processing corpus into lines and conversations...

Writing newly formatted file...


In [None]:
# Show the first lines of the formatted sentence-pair file.
print("\nSample lines from file:")
printLines(datafile)


Sample lines from file:
b'They do to!\tThey do not!\n'
b'She okay?\tI hope so.\n'
b"Wow\tLet's go.\n"
b'"I\'m kidding.  You know how sometimes you just become this ""persona""?  And you don\'t know how to quit?"\tNo\n'
b"No\tOkay -- you're gonna need to learn how to lie.\n"
b"I figured you'd get to the good stuff eventually.\tWhat good stuff?\n"
b'What good stuff?\t"The ""real you""."\n'
b'"The ""real you""."\tLike my fear of wearing pastels?\n'
b'do you listen to this crap?\tWhat crap?\n'
b"What crap?\tMe.  This endless ...blonde babble. I'm like, boring myself.\n"


In [None]:
# Indices reserved for the special "PAD", "SOS" and "EOS" vocabulary entries
# (presumably padding, start-of-sentence and end-of-sentence markers).
PAD_token = 0
SOS_token = 1
EOS_token = 2

class Voc:
  """Vocabulary that maps words to integer indices and back, with counts.

  Indices 0-2 are reserved for the PAD/SOS/EOS entries; ordinary words are
  numbered from 3 upwards in the order they are first added.
  """

  def __init__(self, name):
    """Create an empty vocabulary labelled ``name``."""
    self.name = name
    # Becomes True after the first trim() call; later calls are no-ops.
    self.trimmed = False
    self._reset()

  def _reset(self):
    """Drop all words, leaving only the three reserved entries."""
    self.word2index = {}  # word -> index
    self.word2count = {}  # word -> number of times it was added
    self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
    self.num_words = 3  # counts the reserved entries as well

  def addSentence(self, sentence):
    """Add every single-space-separated token of ``sentence`` to the vocabulary."""
    for word in sentence.split(" "):
      self.addWord(word)

  def addWord(self, word):
    """Register ``word`` under the next free index, or bump its count if known."""
    if word not in self.word2index:
      self.word2index[word] = self.num_words
      self.word2count[word] = 1
      self.index2word[self.num_words] = word
      self.num_words += 1
    else:
      self.word2count[word] += 1

  def trim(self, min_count):
    """Keep only words that were added at least ``min_count`` times.

    Only the first call has an effect. Kept words are re-indexed from 3 in
    their original order, and because they are re-registered through
    addWord their counts all restart at 1. A summary line is printed.
    """
    if self.trimmed:
      return
    self.trimmed = True

    keep_words = [word for word, count in self.word2count.items()
                  if count >= min_count]

    total_words = len(self.word2index)
    # An empty vocabulary would otherwise divide by zero; report 0 instead.
    kept_ratio = len(keep_words) / total_words if total_words else 0.0
    print("keep_words {} / {} = {:.4f}".format(len(keep_words),
                                               total_words,
                                               kept_ratio))

    self._reset()
    for word in keep_words:
      self.addWord(word)

In [None]:
# Length threshold for filterPair: a sentence is kept only when it has fewer
# than this many space-separated tokens.
MAX_LENGTH = 10

def unicodeToAscii(s):
  """Return ``s`` in NFD form with every category-"Mn" character removed.

  NFD normalization splits accented letters into a base letter plus
  combining marks; dropping the "Mn" (nonspacing mark) characters leaves
  the base letters.
  """
  decomposed = unicodedata.normalize("NFD", s)
  kept = [ch for ch in decomposed if unicodedata.category(ch) != "Mn"]
  return "".join(kept)

def normalizeString(s):
  """Lower-case, strip accents, and reduce ``s`` to letters plus ``. ! ?``.

  Steps, in order: lower-case and strip the input, pass it through
  unicodeToAscii, put a space in front of each ``.``, ``!`` or ``?``,
  replace every run of other non-letter characters with one space,
  collapse whitespace runs, and strip the result.
  """
  text = unicodeToAscii(s.lower().strip())
  substitutions = (
      (r"([.!?])", r" \1"),        # detach sentence punctuation
      (r"[^a-zA-Z.!?]+", r" "),    # anything else non-alphabetic -> space
      (r"\s+", r" "),              # collapse whitespace runs
  )
  for pattern, replacement in substitutions:
    text = re.sub(pattern, replacement, text)
  return text.strip()

def readVocs(datafile, corpus_name):
  """Read the sentence-pair file and create an empty vocabulary.

  Args:
    datafile: path of a UTF-8 text file with one tab-separated record per
      line.
    corpus_name: name given to the new Voc instance.

  Returns:
    A ``(voc, pairs)`` tuple: ``voc`` is a fresh Voc (no words added yet) and
    ``pairs`` holds, for each line of the file, the list of its tab-separated
    fields after normalizeString.
  """
  print("Reading lines...")
  # Use a context manager so the file handle is closed deterministically
  # instead of being left to the garbage collector.
  with open(datafile, encoding="utf-8") as f:
    lines = f.read().strip().split("\n")
  pairs = [[normalizeString(s) for s in l.split("\t")] for l in lines]
  voc = Voc(corpus_name)
  return voc, pairs

def filterPair(p):
  """Return True only when both ``p[0]`` and ``p[1]`` are short enough.

  A sentence is short enough when splitting it on single spaces yields
  fewer than MAX_LENGTH tokens.
  """
  for position in (0, 1):
    if len(p[position].split(" ")) >= MAX_LENGTH:
      return False
  return True

def filterPairs(pairs):
  """Return a new list with the pairs accepted by filterPair, in order."""
  accepted = []
  for candidate in pairs:
    if filterPair(candidate):
      accepted.append(candidate)
  return accepted

def loadPreparedData(corpus, corpus_name, datafile, save_dir):
  """Read the pair file, filter it by length, and populate a vocabulary.

  Args:
    corpus: accepted but not used inside this function.
    corpus_name: name forwarded to readVocs for the new vocabulary.
    datafile: path of the sentence-pair file forwarded to readVocs.
    save_dir: accepted but not used inside this function.

  Returns:
    A ``(voc, pairs)`` tuple: the vocabulary filled from the first two
    sentences of every kept pair, and the pairs that passed filterPairs.
  """
  print("Start preparing training data...")
  voc, all_pairs = readVocs(datafile, corpus_name)
  print("Read {!s} sentence pairs".format(len(all_pairs)))

  short_pairs = filterPairs(all_pairs)
  print("Trimmed to {!s} sentence pairs".format(len(short_pairs)))

  print("Counting words...")
  for entry in short_pairs:
    # Register the input sentence first, then the reply.
    for position in (0, 1):
      voc.addSentence(entry[position])
  print("Counted words:", voc.num_words)
  return voc, short_pairs

In [None]:
# Relative path "data/save"; it is passed to loadPreparedData, which does not
# use it in the code shown in this notebook.
save_dir = os.path.join("data", "save")
# Build the vocabulary and the length-filtered sentence pairs.
voc, pairs = loadPreparedData(corpus, corpus_name, datafile, save_dir)
# Print some pairs to validate
print("\npairs:")
for pair in pairs[:10]:
  print(pair)


Start preparing training data...
Reading lines...
Read 221282 sentence pairs
Trimmed to 64313 sentence pairs
Counting words...
Counted words: 18082

pairs:
['they do to !', 'they do not !']
['she okay ?', 'i hope so .']
['wow', 'let s go .']
['what good stuff ?', 'the real you .']
['the real you .', 'like my fear of wearing pastels ?']
['do you listen to this crap ?', 'what crap ?']
['well no . . .', 'then that s all you had to say .']
['then that s all you had to say .', 'but']
['but', 'you always been this selfish ?']
['have fun tonight ?', 'tons']


In [None]:
# Minimum number of occurrences a word needs to stay in the vocabulary.
MIN_COUNT = 3

def trimRareWords(voc, pairs, MIN_COUNT):
  """Trim rare words from ``voc`` and drop pairs that contain any of them.

  Args:
    voc: vocabulary object providing ``trim(min_count)`` and a
      ``word2index`` mapping.
    pairs: sequence of ``[input_sentence, output_sentence]`` entries.
    MIN_COUNT: threshold passed to ``voc.trim``.

  Returns:
    A new list with the pairs whose two sentences consist solely of words
    still present in ``voc.word2index`` after trimming. A summary line is
    printed; for an empty ``pairs`` the reported ratio is 0.
  """
  voc.trim(MIN_COUNT)

  def _all_known(sentence):
    # A sentence survives only if every space-separated token is still known.
    return all(word in voc.word2index for word in sentence.split(" "))

  keep_pairs = [pair for pair in pairs
                if _all_known(pair[0]) and _all_known(pair[1])]

  total_pairs = len(pairs)
  # Guard the ratio so an empty input does not raise ZeroDivisionError.
  kept_ratio = len(keep_pairs) / total_pairs if total_pairs else 0.0
  print("Trimmed from {} pairs to {}, {:.4f} of total".format(total_pairs,
                                                              len(keep_pairs),
                                                              kept_ratio))
  return keep_pairs

# Trim the vocabulary and keep only pairs made of surviving words. This rebinds
# `pairs`, so re-running the cell operates on the already-trimmed list.
pairs = trimRareWords(voc, pairs, MIN_COUNT)

keep_words 7833 / 18079 = 0.4333
Trimmed from 64313 pairs to 53131, 0.8261 of total
