In [None]:
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from datetime import datetime as dt
import datetime
import re
from collections import Counter

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
cuda = torch.cuda.is_available()
device = torch.device("cuda") if cuda else torch.device("cpu")

* Preprocess the data
* Model
* Loss function

In [None]:
SRT_PATH = 'downton.abbey.s01e01.srt'

# `with` guarantees the file is closed. utf-8-sig also drops a leading
# byte-order mark if the subtitle editor wrote one.
# NOTE(review): assumes the .srt is UTF-8 encoded — confirm for other files.
with open(SRT_PATH, 'r', encoding='utf-8-sig') as srt_file:
  content = srt_file.read()

# An SRT cue is "index\ntiming\ntext..." and cues are separated by blank lines.
# Stripping each block (and the whole file) prevents the trailing newline or
# extra blank lines from producing empty / shifted cues; anything that still
# lacks all three fields is discarded so downstream code can index [1] and [2].
group = [
    cue
    for cue in (block.strip().split("\n", 2) for block in content.strip().split("\n\n"))
    if len(cue) == 3
]
print(len(group), "cues; first five:", group[:5])

[['1', '00:01:10,780 --> 00:01:13,060', 'Oh, my God!'], ['2', '00:01:16,580 --> 00:01:18,580', "That's impossible."], ['3', '00:01:19,280 --> 00:01:20,580', "I'll take it up there now."], ['4', '00:01:21,080 --> 00:01:23,690', "Don't be stupid. None of\nthem will be up for hours,"], ['5', '00:01:23,700 --> 00:01:25,340', 'and what difference will it make?'], ['6', '00:01:25,350 --> 00:01:27,360', "Jimmy'll do it when he comes in."], ['7', '00:01:41,470 --> 00:01:44,060', 'April 1912'], ['8', '00:01:45,420 --> 00:01:47,380', "Six o'clock!"], ['9', '00:01:49,220 --> 00:01:52,240', 'Thank you, Daisy... Anna?'], ['10', '00:01:56,120 --> 00:01:57,610', 'Just for once in my life,'], ['11', '00:01:57,620 --> 00:02:00,480', "I'd like to sleep\nuntil I woke up natural."], ['12', '00:02:01,380 --> 00:02:04,480', '- Is your fire still in?\n- Yes, Mrs. Patmore.'], ['13', '00:02:04,490 --> 00:02:06,630', 'Ooh, my, will wonders never cease?'], ['14', '00:02:06,660 --> 00:02:09,410', "- Have you laid

In [None]:
def _cue_bounds(timing):
  """Parse an SRT timing line like '00:01:10,780 --> 00:01:13,060' into (start, end).

  Anything after the end timestamp (e.g. SRT positioning hints such as
  'X1:10 X2:20') is ignored.
  """
  start_text, end_text = timing.split(" --> ", 1)
  start = dt.strptime(start_text.strip(), '%H:%M:%S,%f')
  end = dt.strptime(end_text.split()[0], '%H:%M:%S,%f')
  return start, end


def dialogue_pairs(group, max_gap_seconds=4):
  """Pair consecutive subtitle cues that look like an utterance and its reply.

  Args:
    group: list of parsed SRT cues, each [index, timing, text] where timing is
      'HH:MM:SS,mmm --> HH:MM:SS,mmm'. Cues with fewer than three fields
      (e.g. an empty trailing block) are skipped.
    max_gap_seconds: a cue is paired with the previous cue when it starts no
      later than this many seconds after the previous cue ends.

  Returns:
    list of [previous_text, current_text] pairs in subtitle order.
  """
  max_gap = datetime.timedelta(seconds=max_gap_seconds)
  cues = [cue for cue in group if len(cue) >= 3]
  # Parse every timing line once instead of twice per iteration.
  bounds = [_cue_bounds(cue[1]) for cue in cues]
  pairs = []
  for i in range(1, len(cues)):
    start = bounds[i][0]
    prev_end = bounds[i - 1][1]
    if start <= prev_end + max_gap:
      pairs.append([cues[i - 1][2], cues[i][2]])
  return pairs

In [None]:
# Build [utterance, reply] pairs from the parsed subtitle cues.
pairs = dialogue_pairs(group=group)

In [None]:
# (pattern, replacement) rules applied in order. Order matters: the specific
# negations "won't"/"can't" must run before the generic "n't" rule, and "n't"
# must run before the dropped-g rule "n'" (e.g. "goin'" -> "going").
# \b keeps whole-word rules from firing inside other words ("bit's").
_CONTRACTION_RULES = [
    (re.compile(pattern), replacement)
    for pattern, replacement in [
        (r"\bi'm", "i am"),
        (r"\bhe's", "he is"),
        (r"\bshe's", "she is"),
        (r"\bit's", "it is"),
        (r"\bthat's", "that is"),
        (r"\bwhat's", "what is"),
        (r"\bwhere's", "where is"),
        (r"\bhow's", "how is"),
        (r"'ll", " will"),
        (r"'ve", " have"),
        (r"'re", " are"),
        (r"'d", " would"),
        (r"\bwon't", "will not"),
        (r"\bcan't", "cannot"),
        (r"n't", " not"),
        (r"n'", "ng"),
        (r"'bout", "about"),
        (r"'til", "until"),
    ]
]
# Punctuation to drop. Apostrophes are kept on purpose (e.g. "o'clock").
_PUNCTUATION = re.compile(r"[-()\"#/@;:<>{}`+=~|.!?,]")


def clean_text(text):
    """Normalize one subtitle line for vocabulary building.

    Lower-cases, expands common English contractions, strips punctuation
    (but keeps apostrophes), and collapses all whitespace, including the
    newlines inside multi-line cues, to single spaces.

    Args:
      text: raw subtitle text.

    Returns:
      The cleaned string with no leading, trailing, or repeated spaces.
    """
    # Subtitle files often use the typographic apostrophe (U+2019); map it to
    # ASCII so the contraction rules below can match.
    text = text.lower().replace("\u2019", "'")
    for pattern, replacement in _CONTRACTION_RULES:
        text = pattern.sub(replacement, text)
    text = _PUNCTUATION.sub("", text)
    # split()/join turns "\n" into a space and removes the leading and double
    # spaces left behind where dialogue dashes ("- Yes, ...") were stripped.
    return " ".join(text.split())


In [None]:
# Normalize both sides (utterance at index 0, reply at index 1) of every pair in place.
for pair in pairs:
  for side in (0, 1):
    pair[side] = clean_text(pair[side])

In [None]:
# Show the size and a small sample instead of dumping every pair into the output.
print(f"{len(pairs)} dialogue pairs")
pairs[:10]

[['oh my god', 'that is impossible'], ['that is impossible', 'i will take it up there now'], ['i will take it up there now', 'do not be stupid none of them will be up for hours'], ['do not be stupid none of them will be up for hours', 'and what difference will it make'], ['and what difference will it make', 'jimmy will do it when he comes in'], ['april 1912', "six o'clock"], ["six o'clock", 'thank you daisy anna'], ['thank you daisy anna', 'just for once in my life'], ['just for once in my life', 'i would like to sleep until i woke up natural'], ['i would like to sleep until i woke up natural', ' is your fire still in  yes mrs patmore'], [' is your fire still in  yes mrs patmore', 'ooh my will wonders never cease'], ['ooh my will wonders never cease', " have you laid the servants' hall breakfast  yes mrs patmore"], [" have you laid the servants' hall breakfast  yes mrs patmore", ' and finished blacking that stove  yes mrs patmore'], [' and finished blacking that stove  yes mrs patmore'

In [None]:
queries = [p[0] for p in pairs]
responses = [p[1] for p in pairs]
# Count individual (whitespace-separated) words. Counter(queries) would count
# whole sentences, and the id maps built from these counters would then map
# sentences rather than words.
query_vocab = Counter(word for q in queries for word in q.split())
answer_vocab = Counter(word for r in responses for word in r.split())

In [None]:
# THRESHOLD: minimum count for a word to get its own id in the vocabularies.
# MAX_SENTENCE_LENGTH: a sentence is kept only if its length is below this.
THRESHOLD, MAX_SENTENCE_LENGTH = 3, 10

In [None]:
# Filter query/response *pairs* together so queries[i] is still answered by
# responses[i]; filtering each list separately would misalign them. Length is
# measured in words, since len() on a string counts characters.
kept_pairs = [
    (q, r)
    for q, r in zip(queries, responses)
    if len(q.split()) < MAX_SENTENCE_LENGTH and len(r.split()) < MAX_SENTENCE_LENGTH
]
queries = [q for q, _ in kept_pairs]
responses = [r for _, r in kept_pairs]

In [None]:
def build_word_to_int(vocab, threshold):
  """Map each word seen at least `threshold` times to a contiguous integer id.

  Args:
    vocab: mapping word -> count (e.g. a Counter).
    threshold: minimum count required for a word to receive an id.

  Returns:
    dict word -> id, with ids 0, 1, 2, ... assigned in the vocab's iteration
    order. Words below the threshold are left out.
  """
  frequent = [word for word, count in vocab.items() if count >= threshold]
  return {word: idx for idx, word in enumerate(frequent)}


query_word_to_int = build_word_to_int(query_vocab, THRESHOLD)
response_word_to_int = build_word_to_int(answer_vocab, THRESHOLD)

In [None]:
spec_tokens = ['<PAD>','<SOS>','<EOS>','<UNK>']

# Append the special tokens after the regular words. Word ids are contiguous
# from 0, so len(word_to_int) is exactly the next unused id (len + 1 would
# leave a hole). Tokens that are already present are not reassigned, which
# keeps this cell safe to re-run.
for word_to_int in (query_word_to_int, response_word_to_int):
  for token in spec_tokens:
    if token not in word_to_int:
      word_to_int[token] = len(word_to_int)

In [None]:
# Inverse dictionary
# Inverse lookups: integer id -> token.
query_int_to_word = dict(zip(query_word_to_int.values(), query_word_to_int.keys()))
responses_int_to_word = dict(zip(response_word_to_int.values(), response_word_to_int.keys()))

In [None]:
# Mark the end of every sentence. The lists must be rebuilt: assigning to the
# loop variable (`for q in queries: q = ...`) only rebinds a local name and
# leaves the list unchanged. The endswith() check keeps re-runs from adding a
# second marker.
queries = [q if q.endswith(' <EOS>') else q + ' <EOS>' for q in queries]
responses = [r if r.endswith(' <EOS>') else r + ' <EOS>' for r in responses]

In [None]:
def sentence_to_ints(sentence, word_to_int, unk_token='<UNK>'):
  """Encode a whitespace-tokenized sentence as a list of vocabulary ids.

  Args:
    sentence: cleaned sentence; tokens are separated by whitespace.
    word_to_int: mapping token -> id; must contain `unk_token`.
    unk_token: token whose id is used for out-of-vocabulary words.

  Returns:
    list of ints, one per token.

  Raises:
    KeyError: if `unk_token` is missing from `word_to_int`.
  """
  unk_id = word_to_int[unk_token]
  return [word_to_int.get(word, unk_id) for word in sentence.split()]


# Split into words (iterating a str yields characters), produce exactly one id
# list per sentence, and encode responses with the *response* vocabulary.
query_ints = [sentence_to_ints(q, query_word_to_int) for q in queries]
response_ints = [sentence_to_ints(r, response_word_to_int) for r in responses]