In [1]:
from collections import Counter
import torch
import json
import torch.nn as nn
from torch.utils.data import DataLoader,Dataset
import math
import torch.nn.functional as F
import pickle

# Data set preparation

In [2]:
corpus_movie_conv=r'/mnt/disk1/Gulshan/rnn/movie_conversations.txt'
corpus_movie_lines=r'/mnt/disk1/Gulshan/rnn/movie_lines.txt'
MAX_LENGTH=25

In [3]:
with open(corpus_movie_conv,'r',encoding='iso-8859-1') as f:
    conversations=f.readlines()
with open(corpus_movie_lines,'r',encoding='iso-8859-1') as f:
    conv_lines=f.readlines()

In [4]:
conversations[0]

"u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L194', 'L195', 'L196', 'L197']\n"

In [5]:
conv_lines[0]

'L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!\n'

In [6]:
lines_dict={}
for line in conv_lines:
    objects=line.split(' +++$+++ ')
    lines_dict[objects[0]]=objects[-1]

In [7]:
lines_dict['L1045']

'They do not!\n'

In [8]:
def remove_punc(string):
    punctuations='''!()-[]{};:'"\,<>./?@#$%^&*_~'''
    no_punc=""
    for char in string:
        if char not in punctuations:
            no_punc=no_punc+char
    return no_punc.lower()

In [9]:
remove_punc('They do not!\n')

'they do not\n'

In [10]:
conversations[0].split('+++$+++')[-1]
eval(conversations[0].split('+++$+++')[-1])

['L194', 'L195', 'L196', 'L197']

In [11]:
pairs=[]
for conv in conversations:
     ids=eval(conv.split('+++$+++')[-1])
     for i in range(len(ids)):
         qa_pairs=[]
         if i==len(ids)-1:
             break
         question=remove_punc(lines_dict[ids[i]].strip())
         anwser=remove_punc(lines_dict[ids[i+1]].strip())
         
         qa_pairs.append(question.split()[:MAX_LENGTH])
         qa_pairs.append(anwser.split()[:MAX_LENGTH])
         
         pairs.append(qa_pairs)

In [12]:
pairs[0]

[['can',
  'we',
  'make',
  'this',
  'quick',
  'roxanne',
  'korrine',
  'and',
  'andrew',
  'barrett',
  'are',
  'having',
  'an',
  'incredibly',
  'horrendous',
  'public',
  'break',
  'up',
  'on',
  'the',
  'quad',
  'again'],
 ['well',
  'i',
  'thought',
  'wed',
  'start',
  'with',
  'pronunciation',
  'if',
  'thats',
  'okay',
  'with',
  'you']]

In [13]:
# with open("pairs", "wb") as fp:   #Pickling
#     pickle.dump(pairs, fp)

In [14]:
word_freq=Counter()
for pair in pairs:
    word_freq.update(pair[0]) # only update unique words
    word_freq.update(pair[1]) # only update unique words

In [15]:
min_word_freq=5 # mininum no of words should occur
words=[w for w in word_freq.keys() if word_freq[w]>min_word_freq]
print(words[0])
word_map={k:v+1 for v,k in enumerate(words)}

can


In [16]:
print(len(word_map))
word_map['<Start>']=len(word_map)+1
word_map['<END>']=len(word_map)+1
word_map['<PAD>']=0
word_map['<unkown>']=len(word_map)+1
len(word_map)

18239


18243

In [17]:
# with open('WORDMAP_corpus.json','w') as f:
#     json.dump(word_map,f)

In [18]:
# with open('/mnt/disk1/Gulshan/rnn/transformer/WORDMAP_corpus.json','r') as f:
#     word_map=json.load(f)
# with open("pairs", "rb") as fp:   #Pickling
#     pairs= pickle.load(fp)

In [19]:
def encode_questions(words,word_map):
    enc_q=[word_map.get(word,word_map['<unkown>']) for word in words] + [word_map['<PAD>']]*(MAX_LENGTH-len(words)) # padding
    return enc_q
def encode_answer(words,word_map):
    enc_q=[word_map['<Start>']]+[word_map.get(word,word_map['<unkown>']) for word in words] +[word_map['<END>']]+[word_map['<PAD>']]*(MAX_LENGTH-len(words)) # padding
    return enc_q

In [20]:
encode_questions(pairs[0][0],word_map)


[1,
 2,
 3,
 4,
 5,
 18243,
 18243,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 18243,
 13,
 14,
 15,
 16,
 17,
 18243,
 18,
 0,
 0,
 0]

In [21]:
pairs_encoded=[]
for pair in pairs:
    ques=encode_questions(pair[0],word_map)
    anws=encode_answer(pair[1],word_map)
    pairs_encoded.append([ques,anws])

In [22]:
pairs_encoded[0]

[[1,
  2,
  3,
  4,
  5,
  18243,
  18243,
  6,
  7,
  8,
  9,
  10,
  11,
  12,
  18243,
  13,
  14,
  15,
  16,
  17,
  18243,
  18,
  0,
  0,
  0],
 [18240,
  19,
  20,
  21,
  22,
  23,
  24,
  18243,
  25,
  26,
  27,
  24,
  28,
  18241,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0]]

In [23]:
# with open('pairs_encoded.json','w') as f:
#     json.dump(pairs_encoded,f)a
