# Build a chatbot using a *seq2seq* model

## Stage 0. Download the data

In [1]:
# Path of the corpus file; every line holds two sentences separated by a tab.
FILE = 'formatted_movie_lines.txt'

In [2]:
def print_lines(file, n=10):
    """Print the first ``n`` raw lines of ``file``, each followed by ``--``.

    The file is opened in binary mode, so every line is shown as a ``bytes``
    literal (tabs and newlines appear as escape sequences).

    Args:
        file: Path of the file to preview.
        n: Number of leading lines to show.
    """
    with open(file, 'rb') as handle:
        head = handle.readlines()[:n]
    for raw_line in head:
        print(raw_line, '--', sep='\n')

# Preview the first three raw lines of the corpus file.
print_lines(FILE, n=3)

b"Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\tWell, I thought we'd start with pronunciation, if that's okay with you.\n"
--
b"Well, I thought we'd start with pronunciation, if that's okay with you.\tNot the hacking and gagging and spitting part.  Please.\n"
--
b"Not the hacking and gagging and spitting part.  Please.\tOkay... then how 'bout we try out some French cuisine.  Saturday?  Night?\n"
--


## Stage 1. Pre-process the dataset

### Vocabulary object

In [3]:
class Vocab:
    """Two-way word/index mapping with per-word occurrence counts.

    Indices 0, 1 and 2 are reserved for the PAD, SOS and EOS tokens, so the
    first real word receives index 3. The reserved tokens live only in
    ``index2word``; they are not keys of ``word2index``, hence
    ``has_word('PAD')`` is False.

    Attributes:
        trimmed: True once :meth:`trim` has run.
        word2index: Maps a word to its integer index.
        word2count: Maps a word to the number of times it was added.
        index2word: Maps an index (reserved tokens included) to its word.
        num_words: Number of entries in ``index2word``.
    """

    PAD_token = 0  # Used for padding short sentences
    SOS_token = 1  # Start-of-sentence token
    EOS_token = 2  # End-of-sentence token

    def __init__(self):
        self.trimmed = False
        self._reset()

    def _reset(self):
        """Empty the vocabulary, keeping only the three reserved tokens."""
        self.word2index = {}
        self.word2count = {}
        self.index2word = {self.PAD_token: "PAD", self.SOS_token: "SOS", self.EOS_token: "EOS"}
        self.num_words = 3  # Count SOS, EOS, PAD

    def add_sentence(self, sentence):
        """Add every space-separated token of ``sentence`` to the vocabulary.

        The split is on a single space, so an empty ``sentence`` registers
        the empty string as a word (``''.split(' ') == ['']``).
        """
        for word in sentence.split(' '):
            self.add_word(word)

    def add_word(self, word):
        """Register one occurrence of ``word``, assigning an index if it is new."""
        if word in self.word2index:
            # This word is already in the vocabulary.
            self.word2count[word] += 1
        else:
            # This is a new word: it takes the next free index.
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1

    def trim(self, min_count):
        """Drop every word that was added fewer than ``min_count`` times.

        Kept words are re-indexed contiguously from 3 in their original
        insertion order, and their occurrence counts are preserved. Only the
        first call has any effect; later calls return immediately, whatever
        ``min_count`` they are given.
        """
        if self.trimmed:
            # Already trimmed, nothing to do.
            return
        self.trimmed = True

        kept = [(word, count) for word, count in self.word2count.items()
                if count >= min_count]

        total = len(self.word2index)
        # An empty vocabulary would otherwise raise ZeroDivisionError here.
        ratio = len(kept) / total if total else 0.0
        print('keep_words {} / {} = {:.4f}'.format(len(kept), total, ratio))

        # Rebuild the mappings from the kept words only.
        self._reset()
        for word, count in kept:
            self.add_word(word)
            # add_word() starts a new word at 1; restore the real frequency.
            self.word2count[word] = count

    def has_word(self, word):
        """Return True if ``word`` currently has an index in the vocabulary."""
        return word in self.word2index

    def __str__(self):
        return 'Num words: {}'.format(self.num_words)

### Pre-processing

In [4]:
from tqdm import tqdm
import re

def load_pairs(file, n=None):
    """Read a tab-separated corpus file into a list of field lists.

    The whole file is read as UTF-8, stripped of leading/trailing
    whitespace, and split into lines; each line is then split on tabs.
    A line without a tab yields a one-element list.

    Args:
        file: Path of the corpus file.
        n: If given, only the first ``n`` lines are loaded; ``None`` loads
            every line.

    Returns:
        A list with one entry per line, each a list of that line's
        tab-separated fields.
    """
    # The context manager closes the handle even if reading fails.
    with open(file, encoding='utf-8') as datafile:
        lines = datafile.read().strip().split('\n')
    if n is not None:
        lines = lines[:n]
    return [line.split('\t') for line in tqdm(lines)]

# Load every line of the corpus as a list of its tab-separated fields.
pairs = load_pairs(FILE)
print('Original set size: ', len(pairs))

import unicodedata
def unicode2Ascii(s):
    """Strip combining marks (accents) from ``s``.

    The string is decomposed with NFD normalization and every character in
    Unicode category ``Mn`` is dropped. Characters that have no such
    decomposition are returned unchanged, so the result is not guaranteed
    to be pure ASCII.
    """
    decomposed = unicodedata.normalize('NFD', s)
    base_chars = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']
    return ''.join(base_chars)
def normalize_pairs(pairs):
    """Return a new list with the first two sentences of every pair normalized.

    Each sentence is lower-cased, stripped, passed through
    ``unicode2Ascii``, then rewritten so that ``.``, ``!`` and ``?`` are
    preceded by a space, every run of other non-letter characters becomes a
    single space, and whitespace is collapsed and trimmed.
    """
    # Applied in order; each step works on the result of the previous one.
    substitutions = (
        (r"([.!?])", r" \1"),
        (r"[^a-zA-Z.!?]+", r" "),
        (r"\s+", r" "),
    )

    def normalize_string(s):
        text = unicode2Ascii(s.lower().strip())
        for pattern, replacement in substitutions:
            text = re.sub(pattern, replacement, text)
        return text.strip()

    normalized = []
    for pair in pairs:
        normalized.append([normalize_string(pair[0]), normalize_string(pair[1])])
    return normalized

# Normalization rewrites the sentences but yields one output pair per input pair.
pairs = normalize_pairs(pairs)
print('After normalization: ', len(pairs))

def filter_pairs_by_length(pairs, max_length: int = 999999):
    """Keep the pairs whose two sentences are both shorter than ``max_length``.

    Length is the number of pieces produced by splitting on a single space,
    and the bound is exclusive: a sentence with exactly ``max_length``
    pieces is rejected.
    """
    kept = []
    for pair in pairs:
        if len(pair[0].split(' ')) >= max_length:
            continue
        if len(pair[1].split(' ')) >= max_length:
            continue
        kept.append(pair)
    return kept

# Exclusive upper bound on the number of space-separated tokens per sentence.
MAX_LENGTH = 10
pairs = filter_pairs_by_length(pairs, max_length=MAX_LENGTH)
print('After filtering:  ', len(pairs))

100%|██████████| 221282/221282 [00:00<00:00, 712108.12it/s]


Original set size:  221282
After normalization:  221282
After filtering:   64271


In [5]:
# generate vocabulary
def generate_vocabulary(pairs):
    """Build a ``Vocab`` from the first two sentences of every pair.

    Returns:
        A fresh ``Vocab`` to which both sentences of each pair were added,
        pair by pair, first sentence before second.
    """
    vocabulary = Vocab()
    for pair in pairs:
        for position in (0, 1):
            vocabulary.add_sentence(pair[position])
    return vocabulary

# Minimum number of occurrences a word needs to survive vocabulary trimming;
# it is passed to filter_pairs_with_rare_words in a later cell.
MIN_WORD_COUNT = 3
vocab = generate_vocabulary(pairs)
print(vocab)

Num words: 18008


In [6]:
def filter_pairs_with_rare_words(pairs, vocab, min_count):
    """Trim ``vocab`` and drop every pair that uses a word it no longer has.

    ``vocab.trim(min_count)`` is called first, so ``vocab`` is modified in
    place. A pair is kept only if every space-separated word of both of its
    sentences passes ``vocab.has_word``.

    Args:
        pairs: List of pairs; items 0 and 1 of each pair are sentences.
        vocab: Vocabulary providing ``trim`` and ``has_word``.
        min_count: Threshold forwarded to ``vocab.trim``.

    Returns:
        A ``(keep_pairs, vocab)`` tuple: the surviving pairs and the same
        (now trimmed) vocabulary object.
    """
    vocab.trim(min_count)
    print(vocab)

    def only_known_words(sentence):
        # all() stops at the first unknown word, like the explicit break did.
        return all(vocab.has_word(word) for word in sentence.split(' '))

    keep_pairs = [
        pair for pair in pairs
        if only_known_words(pair[0]) and only_known_words(pair[1])
    ]

    # An empty input would otherwise raise ZeroDivisionError in the report.
    kept_fraction = len(keep_pairs) / len(pairs) if pairs else 0.0
    print('Trimmed from {} pairs to {} pairs, {:.4f} of total'.format(
        len(pairs), len(keep_pairs), kept_fraction))

    return keep_pairs, vocab

# Trims `vocab` in place and rebinds `pairs` to the pairs made only of kept words.
pairs, vocab = filter_pairs_with_rare_words(pairs, vocab, MIN_WORD_COUNT)

keep_words 7823 / 18005 = 0.4345
Num words: 7826
Trimmed from 64271 pairs to 53165 pairs, 0.8272 of total


In [7]:
# Dataset and Dataloader
def tokenize_pairs(pairs, vocab):
    """Replace each word of every pair with its index from ``vocab.word2index``.

    Sentences are split on a single space. A word missing from
    ``vocab.word2index`` raises ``KeyError``.

    Returns:
        A list of ``[indices_of_sentence_0, indices_of_sentence_1]`` lists,
        one per input pair.
    """
    def encode(sentence):
        return [vocab.word2index[word] for word in sentence.split(' ')]

    return [[encode(pair[0]), encode(pair[1])] for pair in pairs]

tokenized_pairs = tokenize_pairs(pairs, vocab)

# Show up to five pairs next to their token indices. Slicing instead of
# indexing 0..4 avoids an IndexError when fewer than five pairs remain.
for pair, tokens in zip(pairs[:5], tokenized_pairs[:5]):
    print(pair)
    print(tokens)
    print('\n')

['there .', 'where ?']
[[3, 4], [5, 6]]


['you have my word . as a gentleman', 'you re sweet .']
[[7, 8, 9, 10, 4, 11, 12, 13], [7, 14, 15, 4]]


['hi .', 'looks like things worked out tonight huh ?']
[[16, 4], [17, 18, 19, 20, 21, 22, 23, 6]]


['have fun tonight ?', 'tons']
[[8, 31, 22, 6], [32]]


['well no . . .', 'then that s all you had to say .']
[[33, 34, 4, 4, 4], [35, 36, 37, 38, 7, 39, 40, 41, 4]]




In [None]:
import torch.nn as nn

class Encoder(nn.Module):
    
    def __init__(self):
        pass
    
    def forward(self, )