# A chatbot based on a seq2seq model with attention

This chatbot follows the encoder-decoder approach with [Bahdanau attention](https://arxiv.org/pdf/1409.0473.pdf).
* A GRU encoder maps a tokenized English sentence to a sequence of hidden states, one per input token.
* At each time step, a GRU decoder uses its own state and all encoder states to compute a context vector.
* During training, the decoder takes the context vector and the previous target token at each time step to predict the next token (teacher forcing); during testing, it takes the context vector and its own previous prediction instead.

In the inference experiments at the end of this notebook, the responses are mostly grammatical, although not always relevant to the question.

In [None]:
import re
import os
import pickle
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_datasets as tfds
from tqdm import tqdm
from sklearn.utils import shuffle

CONV_TSV_PATH = 'conv_pair.tsv'
TOKENIZER_PATH = 'tokenizer.pkl'

## Prepare the dataset
I use the [Cornell Movie Dialogs](http://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html) corpus for training.  
The dataset preparation reuses some code from this [TensorFlow example notebook](https://colab.research.google.com/github/tensorflow/examples/blob/master/community/en/transformer_chatbot.ipynb).  
However, I use more thorough text-cleaning heuristics. For example, the HTML marker `<u>...</u>` appears frequently in the raw text and should be removed.
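As a rough illustration of what these heuristics do, the sketch below applies the same sequence of regular-expression steps as the `text_clean` function in this notebook to a single line. The function name `clean_demo` and the sample string are made up for the example; the string is not taken from the corpus.

```python
import re

def clean_demo(s):
    """Same regex steps as the notebook's text_clean, repeated so the example is self-contained."""
    s = s.lower()
    s = re.sub(r'<u>|</u>', '', s)            # drop underline markers
    s = re.sub(r'[/(){}\[\]\|@]', ' ', s)     # brackets, slashes, pipes -> space
    s = re.sub(r'[^0-9a-z ,.!?]', ' ', s)     # any other symbol -> space
    s = re.sub(r'([,.!?])', r' \1 ', s)       # put spaces around punctuation
    s = re.sub(r'\s+', ' ', s)                # collapse whitespace
    return s.strip()

sample = "<u>Wow</u>, that's GREAT!"   # made-up input
cleaned = clean_demo(sample)
print(cleaned)   # wow , that s great !
```

Note that apostrophes are replaced by spaces, which is why the decoded samples further down contain forms such as "don t".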

In [None]:
path_to_zip = tf.keras.utils.get_file(
    'cornell_movie_dialogs.zip',
    origin='http://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip',
    extract=True)

path_to_dataset = os.path.join(os.path.dirname(path_to_zip), 
                               'cornell movie-dialogs corpus')
path_to_movie_lines = os.path.join(path_to_dataset, 'movie_lines.txt')
path_to_movie_conversations = os.path.join(path_to_dataset, 'movie_conversations.txt')

Downloading data from http://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip


In [None]:
def text_clean(s):
    '''lowercase, strip markup and symbols, and space out punctuation'''
    s = s.lower()
    s = re.sub(r'<u>|</u>', '', s)             # drop <u>...</u> markers
    s = re.sub(r'[/(){}\[\]\|@]', ' ', s)      # brackets, slashes, pipes -> space
    s = re.sub(r'[^0-9a-z ,.!?]', ' ', s)      # any other symbol -> space
    s = re.sub(r'([,.!?])', r' \1 ', s)        # isolate punctuation marks
    s = re.sub(r'\s+', ' ', s)                 # collapse whitespace
    return s.strip()

def load_conversations(max_samples=20000):
    '''build a DataFrame of (question, answer) pairs from consecutive utterances'''
    # line ID -> utterance text
    id2line = {}
    with open(path_to_movie_lines, errors='ignore') as fin:
        for line in fin:
            parts = line.strip().split(' +++$+++ ')
            if len(parts) < 5: parts.append(' ')  # the utterance text is missing
            id2line[parts[0]] = parts[4]

    ques, ans = [], []
    with open(path_to_movie_conversations) as fin:
        for line in fin:
            parts = line.replace('\n', '').split(' +++$+++ ')
            # the last field is a list of line IDs written as "['L1', 'L2', ...]"
            conversation = [lid[1:-1] for lid in parts[3][1:-1].split(', ')]

            # each utterance is the question for the utterance that follows it
            for i in range(len(conversation) - 1):
                q_text = text_clean(id2line[conversation[i]])
                a_text = text_clean(id2line[conversation[i + 1]])
                if q_text and a_text:
                    ques.append(q_text)
                    ans.append(a_text)
                if len(ques) >= max_samples: break
            
            if len(ques) >= max_samples: break
  
    return pd.DataFrame(list(zip(ques, ans)), 
                        columns=['ques', 'ans'])
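
The pairing logic can be seen on a tiny invented example. The sketch below assumes the same ` +++$+++ ` field separator and the same list-like layout of line IDs that `load_conversations` parses; the IDs, character names, and utterances are made up.

```python
SEP = ' +++$+++ '

# Made-up lines in the layout parsed above: line id, character id, movie id, name, text
movie_lines = [
    "L1 +++$+++ u0 +++$+++ m0 +++$+++ ANNA +++$+++ Where are you going?",
    "L2 +++$+++ u1 +++$+++ m0 +++$+++ BEN +++$+++ Home.",
    "L3 +++$+++ u0 +++$+++ m0 +++$+++ ANNA +++$+++ Already?",
]
# Made-up conversation record: the last field lists the line ids in order
conversation_line = "u0 +++$+++ u1 +++$+++ m0 +++$+++ ['L1', 'L2', 'L3']"

# line id -> utterance text
id2line = {}
for line in movie_lines:
    parts = line.split(SEP)
    id2line[parts[0]] = parts[4]

# strip the surrounding brackets, split on ', ', then strip the quotes
ids = [tok[1:-1] for tok in conversation_line.split(SEP)[3][1:-1].split(', ')]

# each utterance is the "question" for the utterance that follows it
pairs = [(id2line[a], id2line[b]) for a, b in zip(ids, ids[1:])]
for q, a in pairs:
    print('Q:', q, '| A:', a)
```

A conversation of three utterances therefore yields two training pairs, with the middle utterance used once as an answer and once as a question.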

## Tokenization
I save the questions and answers in a `tsv` file for later use.  
I use the `SubwordTextEncoder` tokenizer from **TensorFlow Datasets**, which has a byte-level fallback. The tokenizer is saved for later use too.

In [None]:
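# create the dataframe
df = load_conversations()
df.to_csv(CONV_TSV_PATH, sep='\t', index=False)

# build vocabulary & tokenizer from examples
# questions and answers are passed as separate sentences;
# `df['ques'] + df['ans']` would glue each pair into a single string
# (newer tensorflow_datasets releases expose this class under tfds.deprecated.text)
tokenizer = tfds.features.text.SubwordTextEncoder.build_from_corpus(
    df['ques'].tolist() + df['ans'].tolist(), target_vocab_size=2**13)
START_TOK = [tokenizer.vocab_size]
END_TOK = [tokenizer.vocab_size + 1]

with open(TOKENIZER_PATH, 'wb') as fout:
    pickle.dump((tokenizer, START_TOK, END_TOK), fout)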

In [None]:
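# load the dataset & the tokenizer
# keep_default_na=False stops pandas from reading a sentence such as "null" as NaN
df = pd.read_csv(CONV_TSV_PATH, sep='\t', keep_default_na=False)
with open(TOKENIZER_PATH, 'rb') as fin:
    tokenizer, START_TOK, END_TOK = pickle.load(fin)
VOCAB_SIZE = tokenizer.vocab_size + 2  # + the START and END tokens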

In [None]:
def texts2ids(texts, tokenizer, max_len=40, doclean=False):
    '''list of texts -> list of token ID arrays'''
    tokenized = []
    for t in texts:
        if doclean: t = text_clean(t)
        sent = START_TOK + tokenizer.encode(t) + END_TOK
        if len(sent) > max_len:
            sent = sent[:max_len - 1] + END_TOK
        tokenized.append(sent)
        
    return tf.keras.preprocessing.sequence.pad_sequences(
        tokenized, maxlen=max_len, padding='post')
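
To make the framing concrete, here is a small pure-Python sketch of what `texts2ids` does to one already-encoded sentence: add the start and end IDs, truncate while keeping the end ID, then right-pad with zeros. The helper name `frame_and_pad` and the toy ID values are assumptions for illustration; the notebook itself relies on `pad_sequences` for the padding step.

```python
def frame_and_pad(token_ids, start_id, end_id, max_len, pad_id=0):
    """Hypothetical helper: wrap ids in START/END, truncate, and right-pad."""
    sent = [start_id] + list(token_ids) + [end_id]
    if len(sent) > max_len:
        # keep the first max_len - 1 ids and force END into the last slot
        sent = sent[:max_len - 1] + [end_id]
    return sent + [pad_id] * (max_len - len(sent))

# Toy ids: pretend the tokenizer's vocab_size is 100,
# so START = 100 and END = 101, following the notebook's convention.
START, END = 100, 101
short = frame_and_pad([5, 6, 7], START, END, max_len=8)
long_ = frame_and_pad(range(1, 11), START, END, max_len=8)
print(short)   # [100, 5, 6, 7, 101, 0, 0, 0]
print(long_)   # [100, 1, 2, 3, 4, 5, 6, 101]
```

Because the start and end IDs sit just past the tokenizer's own vocabulary, they can later be filtered out with a simple `t < tokenizer.vocab_size` check before decoding.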


I use only 20,000 question-answer pairs. They are now converted to padded arrays of token IDs.

In [None]:
# prepare tokenized questions and answers
SAMPLE_SIZE = 20000
X, y = shuffle(df['ques'], df['ans'], 
               random_state=0,
               n_samples=SAMPLE_SIZE)
X, y = texts2ids(X, tokenizer), texts2ids(y, tokenizer)


This is what the dataset looks like when de-tokenized:

In [None]:
for X1, y1 in zip(X[:5], y[:5]):
    print('Q:', tokenizer.decode([t for t in X1 if t < tokenizer.vocab_size]))
    print('A:', tokenizer.decode([t for t in y1 if t < tokenizer.vocab_size]))

Q: right . farewell . . . for the last time . . . may the gods prevent . . .
A: no , don t say anything else !
Q: the end of the recordare statuens in parte dextra .
A: so now the confutatis . confutatis maledictis . when the wicked are confounded . flammis acribus addictis . how would you translate 
Q: remember the guy who cheated at the table ?
A: you don t like cheats , do you .
Q: you can t have sven s father sitting next to sven . they ll argue the whole time .
A: that s true . you d better sit there . you there , and ornulf there .
Q: the bedroom .
A: there is only one bed .


## Encoder-decoder with attention
In defining the model architecture, I use code from this [TensorFlow tutorial](https://www.tensorflow.org/tutorials/text/nmt_with_attention).

In [None]:
# architecture

class Encoder(tf.keras.Model):
    def __init__(self, vocab_size, emb_size, enc_units,
                       batch_size, dropout):
        super().__init__()
        self.batch_size = batch_size
        self.enc_units = enc_units
        self.embedding = tf.keras.layers.Embedding(vocab_size, emb_size)
        self.gru = tf.keras.layers.GRU(
            self.enc_units,
            return_sequences=True,
            return_state=True,
            dropout=dropout,
            recurrent_initializer='glorot_uniform')

    def call(self, x, hidden):
        x = self.embedding(x)
        output, state = self.gru(x, initial_state=hidden)
        return output, state
    
    def initial_hidden_state(self):
        return tf.zeros((self.batch_size, self.enc_units))


class BAttention(tf.keras.layers.Layer):
    def __init__(self, units):
        super().__init__()
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

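    def call(self, query, values):
        # query: (batch, units) -> (batch, 1, units) so it broadcasts over time
        query_with_time_axis = tf.expand_dims(query, 1)

        # additive score for each encoder step: (batch, src_len, 1)
        score = self.V(tf.nn.tanh(
            self.W1(query_with_time_axis) + self.W2(values)))

        # normalize over the time axis, then take the weighted sum of the values
        weights = tf.nn.softmax(score, axis=1)
        context = weights * values
        context = tf.reduce_sum(context, axis=1)

        return context, weights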


class Decoder(tf.keras.Model):
    def __init__(self, vocab_size, emb_size, dec_units,
                       batch_size, dropout):
        super().__init__()
        self.batch_size = batch_size
        self.dec_units = dec_units
        self.embedding = tf.keras.layers.Embedding(vocab_size, emb_size)
        self.gru = tf.keras.layers.GRU(
            self.dec_units,
            return_sequences=True,
            return_state=True,
            dropout=dropout,
            recurrent_initializer='glorot_uniform')
        self.fc = tf.keras.layers.Dense(vocab_size)
        self.attention = BAttention(dec_units)

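    def call(self, x, hidden, enc_output):
        # x: (batch, 1) token IDs -> (batch, 1, emb_size)
        x = self.embedding(x)

        # context: (batch, enc_units), weights: (batch, src_len, 1)
        context, weights = self.attention(hidden, enc_output)
        x = tf.concat([tf.expand_dims(context, 1), x], axis=-1)
        
        # the GRU is called without an initial state, so `hidden`
        # influences this step only through the attention context
        output, state = self.gru(x)
        output = tf.reshape(output, (-1, output.shape[2]))
        x = self.fc(output)  # logits: (batch, vocab_size)

        return x, state, weights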
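
To make the attention computation easier to follow, here is a pure-Python sketch of additive (Bahdanau-style) attention for a single query and three encoder states. It mirrors the structure of `BAttention` (score, softmax over time, weighted sum) but leaves out the bias terms that `Dense` layers add, and all numbers are hand-picked toy values rather than trained weights. The helper names are assumptions for illustration.

```python
import math

def matvec(M, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def additive_attention(query, values, W1, W2, v):
    """score_t = v . tanh(W1 query + W2 values[t]); returns (context, weights)."""
    q_proj = matvec(W1, query)
    scores = []
    for h in values:
        h_proj = matvec(W2, h)
        hidden = [math.tanh(a + b) for a, b in zip(q_proj, h_proj)]
        scores.append(sum(vi * hi for vi, hi in zip(v, hidden)))
    # softmax over the encoder time steps (shifted by the max for stability)
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # context = weighted sum of the encoder states
    dim = len(values[0])
    context = [sum(w * h[d] for w, h in zip(weights, values)) for d in range(dim)]
    return context, weights

# Toy inputs: one decoder state and three encoder states of size 2
query = [1.0, 0.0]
values = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
W1 = [[1.0, 0.0], [0.0, 1.0]]
W2 = [[1.0, 0.0], [0.0, 1.0]]
v = [1.0, 1.0]

context, weights = additive_attention(query, values, W1, W2, v)
print([round(w, 3) for w in weights])
print([round(c, 3) for c in context])
```

The weights sum to one, so the context vector is a convex combination of the encoder states, dominated by the states with the highest scores.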

## Batch generator, hyper-parameters, optimizer, ...  
I reuse some code from the TensorFlow tutorial linked above.

In [None]:
BATCH_SIZE = 64
STEPS_PER_EPOCH = len(X) // BATCH_SIZE

dataset = tf.data.Dataset.from_tensor_slices((
    {
        'inputs': X,
        'dec_inputs': y[:, :-1]
    },
    {
        'targets': y[:, 1:]
    }
))

dataset = dataset.cache()
dataset = dataset.batch(BATCH_SIZE)
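
The two slices `y[:, :-1]` and `y[:, 1:]` implement teacher forcing: at every step the decoder is fed the true previous token and is asked to predict the next one. A minimal sketch on a single row, with toy IDs chosen for illustration (START = 100, END = 101, padding = 0):

```python
# One padded answer row with toy ids
y_row = [100, 11, 12, 13, 101, 0, 0]

dec_inputs = y_row[:-1]   # what the decoder is fed at each step
targets = y_row[1:]       # what it should predict at that step

for step, (fed, expected) in enumerate(zip(dec_inputs, targets)):
    print(step, fed, '->', expected)
```

The first input is the start ID and the first target is the first real token. Steps whose target is the padding ID 0 are the ones a padding mask in the loss is meant to ignore.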

In [None]:
EMBEDDING_SIZE = 256
HIDDEN_UNITS = 256
DROPOUT_RT = 0.5

In [None]:
# create encoder, decoder
encoder = Encoder(VOCAB_SIZE, EMBEDDING_SIZE, HIDDEN_UNITS, 
                  BATCH_SIZE, DROPOUT_RT)
decoder = Decoder(VOCAB_SIZE, EMBEDDING_SIZE, HIDDEN_UNITS, 
                  BATCH_SIZE, DROPOUT_RT)

# Optimizer and the loss
optimizer = tf.keras.optimizers.Adam()
loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')


def loss_func(real, pred):
    loss = loss_obj(real, pred)
    mask = tf.cast(tf.math.not_equal(real, 0), dtype=loss.dtype)
    return tf.reduce_mean(loss * mask)


@tf.function
def train_step(inputs, dec_inputs, targets, enc_hidden):
    loss = 0

    with tf.GradientTape() as tape:
        enc_output, enc_hidden = encoder(inputs, enc_hidden)
        dec_hidden = enc_hidden

        for i in range(targets.shape[1]):
            pred, dec_hidden, _ = decoder(
                tf.expand_dims(dec_inputs[:, i], 1),
                dec_hidden,
                enc_output)
            
            loss += loss_func(targets[:, i], pred)

    batch_loss = loss / int(targets.shape[1])
    trainable = encoder.trainable_variables + decoder.trainable_variables
    gradients = tape.gradient(loss, trainable)
    optimizer.apply_gradients(zip(gradients, trainable))

    return batch_loss
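
The effect of the padding mask in `loss_func` can be checked by hand on a toy batch. The sketch below works with probabilities rather than logits to keep the arithmetic visible, and the function name `masked_mean_nll` is an assumption for illustration. As with `reduce_mean(loss * mask)`, padded positions contribute zero loss but still count in the denominator.

```python
import math

def masked_mean_nll(target_ids, probs, pad_id=0):
    """Mean negative log-likelihood over a batch at one time step.

    Padded targets contribute zero loss but are still counted in the mean.
    """
    losses = []
    for t, p in zip(target_ids, probs):
        nll = -math.log(p[t])
        losses.append(0.0 if t == pad_id else nll)
    return sum(losses) / len(losses)

# Toy batch of 3 examples over a vocabulary of 3 ids (0 is padding)
targets = [1, 2, 0]
probs = [
    [0.1, 0.8, 0.1],     # correct id 1 has probability 0.8
    [0.2, 0.3, 0.5],     # correct id 2 has probability 0.5
    [0.9, 0.05, 0.05],   # padded position, zeroed by the mask
]
loss = masked_mean_nll(targets, probs)
print(round(loss, 4))
```

One consequence of this design is that time steps where most of the batch is padding report a smaller loss than steps where every row still has real tokens.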

## Details of the training experiment
I trained the model for 40 epochs in two runs of 20 epochs each.  
The training loss was **2.12** after epoch 0 and dropped to **0.78** by epoch 39.  
The training log is rather long ...  
The model weights are saved for later use.

In [None]:
EPOCHS = 20

for epoch in range(EPOCHS):
    enc_hidden = encoder.initial_hidden_state()
    total_loss = 0

    for n, batch in tqdm(enumerate(dataset.take(STEPS_PER_EPOCH))):
        batch_loss = train_step(
            batch[0]['inputs'], 
            batch[0]['dec_inputs'], 
            batch[1]['targets'],
            enc_hidden)

        total_loss += batch_loss
        if n % 50 == 0:
            print(' Epoch %s Batch %s Loss %.4f' % (
                epoch, n, batch_loss.numpy()))
    
    print(' Epoch %s Loss %.4f\n' % (
        epoch, total_loss / STEPS_PER_EPOCH))


encoder.save_weights('enc_weights.h5')
decoder.save_weights('dec_weights.h5')

2it [00:31, 21.94s/it]

 Epoch 0 Batch 0 Loss 3.2251


52it [00:37,  7.99it/s]

 Epoch 0 Batch 50 Loss 2.2151


102it [00:43,  7.92it/s]

 Epoch 0 Batch 100 Loss 2.1213


152it [00:50,  7.99it/s]

 Epoch 0 Batch 150 Loss 2.0749


202it [00:56,  8.29it/s]

 Epoch 0 Batch 200 Loss 2.0497


252it [01:02,  7.97it/s]

 Epoch 0 Batch 250 Loss 1.8675


302it [01:08,  8.19it/s]

 Epoch 0 Batch 300 Loss 1.9453


312it [01:09,  4.47it/s]
1it [00:00,  7.81it/s]

 Epoch 0 Loss 2.1228

 Epoch 1 Batch 0 Loss 1.8805


52it [00:06,  8.09it/s]

 Epoch 1 Batch 50 Loss 1.8210


102it [00:12,  7.70it/s]

 Epoch 1 Batch 100 Loss 1.8199


152it [00:18,  7.95it/s]

 Epoch 1 Batch 150 Loss 1.8548


202it [00:25,  8.17it/s]

 Epoch 1 Batch 200 Loss 1.8554


252it [00:31,  8.11it/s]

 Epoch 1 Batch 250 Loss 1.7100


302it [00:37,  8.01it/s]

 Epoch 1 Batch 300 Loss 1.8034


312it [00:38,  8.03it/s]
1it [00:00,  7.68it/s]

 Epoch 1 Loss 1.8345

 Epoch 2 Batch 0 Loss 1.7423


52it [00:06,  7.88it/s]

 Epoch 2 Batch 50 Loss 1.7273


102it [00:12,  8.25it/s]

 Epoch 2 Batch 100 Loss 1.7225


152it [00:18,  7.91it/s]

 Epoch 2 Batch 150 Loss 1.7730


202it [00:25,  8.16it/s]

 Epoch 2 Batch 200 Loss 1.7691


252it [00:31,  8.11it/s]

 Epoch 2 Batch 250 Loss 1.6391


302it [00:37,  7.99it/s]

 Epoch 2 Batch 300 Loss 1.7339


312it [00:38,  8.02it/s]
1it [00:00,  8.58it/s]

 Epoch 2 Loss 1.7451

 Epoch 3 Batch 0 Loss 1.6770


52it [00:06,  8.13it/s]

 Epoch 3 Batch 50 Loss 1.6699


102it [00:12,  7.98it/s]

 Epoch 3 Batch 100 Loss 1.6609


152it [00:18,  7.97it/s]

 Epoch 3 Batch 150 Loss 1.7221


202it [00:25,  8.07it/s]

 Epoch 3 Batch 200 Loss 1.7157


252it [00:31,  8.05it/s]

 Epoch 3 Batch 250 Loss 1.5911


302it [00:37,  7.93it/s]

 Epoch 3 Batch 300 Loss 1.6898


312it [00:38,  8.06it/s]
1it [00:00,  7.80it/s]

 Epoch 3 Loss 1.6904

 Epoch 4 Batch 0 Loss 1.6350


52it [00:06,  8.17it/s]

 Epoch 4 Batch 50 Loss 1.6275


102it [00:12,  8.13it/s]

 Epoch 4 Batch 100 Loss 1.6172


152it [00:18,  8.09it/s]

 Epoch 4 Batch 150 Loss 1.6871


202it [00:25,  7.88it/s]

 Epoch 4 Batch 200 Loss 1.6760


252it [00:31,  7.91it/s]

 Epoch 4 Batch 250 Loss 1.5542


302it [00:37,  8.27it/s]

 Epoch 4 Batch 300 Loss 1.6561


312it [00:38,  8.04it/s]
1it [00:00,  7.32it/s]

 Epoch 4 Loss 1.6504

 Epoch 5 Batch 0 Loss 1.5977


52it [00:06,  7.84it/s]

 Epoch 5 Batch 50 Loss 1.5930


102it [00:12,  8.08it/s]

 Epoch 5 Batch 100 Loss 1.5860


152it [00:18,  8.13it/s]

 Epoch 5 Batch 150 Loss 1.6503


202it [00:25,  8.04it/s]

 Epoch 5 Batch 200 Loss 1.6387


252it [00:31,  8.13it/s]

 Epoch 5 Batch 250 Loss 1.5206


302it [00:37,  8.00it/s]

 Epoch 5 Batch 300 Loss 1.6236


312it [00:38,  8.06it/s]
1it [00:00,  8.08it/s]

 Epoch 5 Loss 1.6169

 Epoch 6 Batch 0 Loss 1.5677


52it [00:06,  7.77it/s]

 Epoch 6 Batch 50 Loss 1.5663


102it [00:12,  7.92it/s]

 Epoch 6 Batch 100 Loss 1.5494


152it [00:19,  8.07it/s]

 Epoch 6 Batch 150 Loss 1.6170


202it [00:25,  8.09it/s]

 Epoch 6 Batch 200 Loss 1.6024


252it [00:31,  8.17it/s]

 Epoch 6 Batch 250 Loss 1.4864


302it [00:37,  8.02it/s]

 Epoch 6 Batch 300 Loss 1.5970


312it [00:39,  7.98it/s]
1it [00:00,  8.24it/s]

 Epoch 6 Loss 1.5855

 Epoch 7 Batch 0 Loss 1.5381


52it [00:06,  8.01it/s]

 Epoch 7 Batch 50 Loss 1.5335


102it [00:12,  8.16it/s]

 Epoch 7 Batch 100 Loss 1.5164


152it [00:18,  7.87it/s]

 Epoch 7 Batch 150 Loss 1.5808


202it [00:25,  7.85it/s]

 Epoch 7 Batch 200 Loss 1.5665


252it [00:31,  8.08it/s]

 Epoch 7 Batch 250 Loss 1.4561


302it [00:37,  8.03it/s]

 Epoch 7 Batch 300 Loss 1.5661


312it [00:38,  8.05it/s]
1it [00:00,  7.99it/s]

 Epoch 7 Loss 1.5547

 Epoch 8 Batch 0 Loss 1.5070


52it [00:06,  8.04it/s]

 Epoch 8 Batch 50 Loss 1.5000


102it [00:12,  7.81it/s]

 Epoch 8 Batch 100 Loss 1.4864


152it [00:18,  7.96it/s]

 Epoch 8 Batch 150 Loss 1.5473


202it [00:25,  8.13it/s]

 Epoch 8 Batch 200 Loss 1.5315


252it [00:31,  7.94it/s]

 Epoch 8 Batch 250 Loss 1.4256


302it [00:37,  7.88it/s]

 Epoch 8 Batch 300 Loss 1.5370


312it [00:38,  8.05it/s]
1it [00:00,  8.12it/s]

 Epoch 8 Loss 1.5246

 Epoch 9 Batch 0 Loss 1.4720


52it [00:06,  8.03it/s]

 Epoch 9 Batch 50 Loss 1.4700


102it [00:12,  8.17it/s]

 Epoch 9 Batch 100 Loss 1.4559


152it [00:18,  7.94it/s]

 Epoch 9 Batch 150 Loss 1.5134


202it [00:25,  8.09it/s]

 Epoch 9 Batch 200 Loss 1.5014


252it [00:31,  8.19it/s]

 Epoch 9 Batch 250 Loss 1.4018


302it [00:37,  8.23it/s]

 Epoch 9 Batch 300 Loss 1.5084


312it [00:38,  8.06it/s]
1it [00:00,  7.74it/s]

 Epoch 9 Loss 1.4949

 Epoch 10 Batch 0 Loss 1.4396


52it [00:06,  7.76it/s]

 Epoch 10 Batch 50 Loss 1.4396


102it [00:12,  8.03it/s]

 Epoch 10 Batch 100 Loss 1.4155


152it [00:18,  8.00it/s]

 Epoch 10 Batch 150 Loss 1.4800


202it [00:25,  8.10it/s]

 Epoch 10 Batch 200 Loss 1.4688


252it [00:31,  7.96it/s]

 Epoch 10 Batch 250 Loss 1.3696


302it [00:37,  7.94it/s]

 Epoch 10 Batch 300 Loss 1.4789


312it [00:38,  8.03it/s]
1it [00:00,  8.06it/s]

 Epoch 10 Loss 1.4644

 Epoch 11 Batch 0 Loss 1.4042


52it [00:06,  7.95it/s]

 Epoch 11 Batch 50 Loss 1.4071


102it [00:12,  7.79it/s]

 Epoch 11 Batch 100 Loss 1.3798


152it [00:19,  7.75it/s]

 Epoch 11 Batch 150 Loss 1.4426


202it [00:25,  8.05it/s]

 Epoch 11 Batch 200 Loss 1.4364


252it [00:31,  8.06it/s]

 Epoch 11 Batch 250 Loss 1.3429


302it [00:37,  8.35it/s]

 Epoch 11 Batch 300 Loss 1.4452


312it [00:39,  8.00it/s]
1it [00:00,  7.76it/s]

 Epoch 11 Loss 1.4320

 Epoch 12 Batch 0 Loss 1.3704


52it [00:06,  8.04it/s]

 Epoch 12 Batch 50 Loss 1.3768


102it [00:12,  8.12it/s]

 Epoch 12 Batch 100 Loss 1.3501


152it [00:19,  7.73it/s]

 Epoch 12 Batch 150 Loss 1.4041


202it [00:25,  7.98it/s]

 Epoch 12 Batch 200 Loss 1.4004


252it [00:31,  8.09it/s]

 Epoch 12 Batch 250 Loss 1.3126


302it [00:37,  8.16it/s]

 Epoch 12 Batch 300 Loss 1.4142


312it [00:39,  7.99it/s]
1it [00:00,  7.93it/s]

 Epoch 12 Loss 1.3997

 Epoch 13 Batch 0 Loss 1.3386


52it [00:06,  8.01it/s]

 Epoch 13 Batch 50 Loss 1.3470


102it [00:12,  7.83it/s]

 Epoch 13 Batch 100 Loss 1.3194


152it [00:19,  7.85it/s]

 Epoch 13 Batch 150 Loss 1.3806


202it [00:25,  7.83it/s]

 Epoch 13 Batch 200 Loss 1.3623


252it [00:31,  7.93it/s]

 Epoch 13 Batch 250 Loss 1.2852


302it [00:38,  8.07it/s]

 Epoch 13 Batch 300 Loss 1.3876


312it [00:39,  7.92it/s]
1it [00:00,  8.28it/s]

 Epoch 13 Loss 1.3688

 Epoch 14 Batch 0 Loss 1.3044


52it [00:06,  7.79it/s]

 Epoch 14 Batch 50 Loss 1.3214


102it [00:12,  7.54it/s]

 Epoch 14 Batch 100 Loss 1.2905


152it [00:19,  7.82it/s]

 Epoch 14 Batch 150 Loss 1.3514


202it [00:25,  8.05it/s]

 Epoch 14 Batch 200 Loss 1.3309


252it [00:31,  8.08it/s]

 Epoch 14 Batch 250 Loss 1.2623


302it [00:38,  8.27it/s]

 Epoch 14 Batch 300 Loss 1.3525


312it [00:39,  7.92it/s]
1it [00:00,  8.44it/s]

 Epoch 14 Loss 1.3379

 Epoch 15 Batch 0 Loss 1.2756


52it [00:06,  7.94it/s]

 Epoch 15 Batch 50 Loss 1.2924


102it [00:12,  8.29it/s]

 Epoch 15 Batch 100 Loss 1.2515


152it [00:18,  8.04it/s]

 Epoch 15 Batch 150 Loss 1.3209


202it [00:24,  8.15it/s]

 Epoch 15 Batch 200 Loss 1.3090


252it [00:30,  8.25it/s]

 Epoch 15 Batch 250 Loss 1.2338


302it [00:37,  7.96it/s]

 Epoch 15 Batch 300 Loss 1.3196


312it [00:38,  8.06it/s]
1it [00:00,  8.02it/s]

 Epoch 15 Loss 1.3076

 Epoch 16 Batch 0 Loss 1.2425


52it [00:06,  8.12it/s]

 Epoch 16 Batch 50 Loss 1.2580


102it [00:12,  7.45it/s]

 Epoch 16 Batch 100 Loss 1.2187


152it [00:19,  7.47it/s]

 Epoch 16 Batch 150 Loss 1.2887


202it [00:26,  8.09it/s]

 Epoch 16 Batch 200 Loss 1.2713


252it [00:32,  7.40it/s]

 Epoch 16 Batch 250 Loss 1.2116


302it [00:39,  7.52it/s]

 Epoch 16 Batch 300 Loss 1.2994


312it [00:40,  7.70it/s]
1it [00:00,  7.95it/s]

 Epoch 16 Loss 1.2781

 Epoch 17 Batch 0 Loss 1.2080


52it [00:06,  7.46it/s]

 Epoch 17 Batch 50 Loss 1.2314


102it [00:13,  7.72it/s]

 Epoch 17 Batch 100 Loss 1.1913


152it [00:19,  7.53it/s]

 Epoch 17 Batch 150 Loss 1.2645


202it [00:26,  7.82it/s]

 Epoch 17 Batch 200 Loss 1.2416


252it [00:32,  7.88it/s]

 Epoch 17 Batch 250 Loss 1.1827


302it [00:39,  7.68it/s]

 Epoch 17 Batch 300 Loss 1.2569


312it [00:40,  7.68it/s]
1it [00:00,  8.10it/s]

 Epoch 17 Loss 1.2492

 Epoch 18 Batch 0 Loss 1.1720


52it [00:06,  8.02it/s]

 Epoch 18 Batch 50 Loss 1.2043


102it [00:13,  7.94it/s]

 Epoch 18 Batch 100 Loss 1.1631


152it [00:19,  7.80it/s]

 Epoch 18 Batch 150 Loss 1.2191


202it [00:25,  8.05it/s]

 Epoch 18 Batch 200 Loss 1.2196


252it [00:32,  7.83it/s]

 Epoch 18 Batch 250 Loss 1.1508


302it [00:38,  7.42it/s]

 Epoch 18 Batch 300 Loss 1.2303


312it [00:39,  7.81it/s]
1it [00:00,  8.34it/s]

 Epoch 18 Loss 1.2200

 Epoch 19 Batch 0 Loss 1.1428


52it [00:06,  7.70it/s]

 Epoch 19 Batch 50 Loss 1.1756


102it [00:13,  7.84it/s]

 Epoch 19 Batch 100 Loss 1.1307


152it [00:19,  7.83it/s]

 Epoch 19 Batch 150 Loss 1.1791


202it [00:25,  7.80it/s]

 Epoch 19 Batch 200 Loss 1.1906


252it [00:32,  7.82it/s]

 Epoch 19 Batch 250 Loss 1.1285


302it [00:38,  7.88it/s]

 Epoch 19 Batch 300 Loss 1.2011


312it [00:39,  7.85it/s]


 Epoch 19 Loss 1.1917



In [None]:
# same as above, another 20 epochs

2it [00:00,  6.24it/s]

 Epoch 20 Batch 0 Loss 1.1127


52it [00:06,  8.01it/s]

 Epoch 20 Batch 50 Loss 1.1451


102it [00:12,  8.19it/s]

 Epoch 20 Batch 100 Loss 1.1086


152it [00:19,  7.86it/s]

 Epoch 20 Batch 150 Loss 1.1572


202it [00:25,  7.77it/s]

 Epoch 20 Batch 200 Loss 1.1617


252it [00:31,  8.01it/s]

 Epoch 20 Batch 250 Loss 1.1142


302it [00:38,  7.91it/s]

 Epoch 20 Batch 300 Loss 1.1720


312it [00:39,  7.89it/s]
1it [00:00,  7.30it/s]

 Epoch 20 Loss 1.1645

 Epoch 21 Batch 0 Loss 1.0839


52it [00:06,  7.90it/s]

 Epoch 21 Batch 50 Loss 1.1272


102it [00:13,  7.91it/s]

 Epoch 21 Batch 100 Loss 1.0841


152it [00:19,  7.79it/s]

 Epoch 21 Batch 150 Loss 1.1298


202it [00:25,  7.76it/s]

 Epoch 21 Batch 200 Loss 1.1493


252it [00:32,  7.88it/s]

 Epoch 21 Batch 250 Loss 1.0942


302it [00:38,  7.92it/s]

 Epoch 21 Batch 300 Loss 1.1515


312it [00:39,  7.89it/s]
1it [00:00,  7.98it/s]

 Epoch 21 Loss 1.1393

 Epoch 22 Batch 0 Loss 1.0674


52it [00:06,  8.05it/s]

 Epoch 22 Batch 50 Loss 1.1029


102it [00:12,  8.15it/s]

 Epoch 22 Batch 100 Loss 1.0602


152it [00:19,  7.78it/s]

 Epoch 22 Batch 150 Loss 1.1036


202it [00:25,  8.10it/s]

 Epoch 22 Batch 200 Loss 1.1208


252it [00:31,  8.04it/s]

 Epoch 22 Batch 250 Loss 1.0704


302it [00:37,  7.94it/s]

 Epoch 22 Batch 300 Loss 1.1407


312it [00:39,  8.00it/s]
1it [00:00,  8.10it/s]

 Epoch 22 Loss 1.1155

 Epoch 23 Batch 0 Loss 1.0359


52it [00:06,  7.85it/s]

 Epoch 23 Batch 50 Loss 1.0880


102it [00:12,  7.89it/s]

 Epoch 23 Batch 100 Loss 1.0283


152it [00:19,  8.02it/s]

 Epoch 23 Batch 150 Loss 1.0782


202it [00:25,  8.02it/s]

 Epoch 23 Batch 200 Loss 1.0939


252it [00:31,  7.94it/s]

 Epoch 23 Batch 250 Loss 1.0484


302it [00:37,  8.01it/s]

 Epoch 23 Batch 300 Loss 1.1141


312it [00:39,  7.96it/s]
1it [00:00,  8.49it/s]

 Epoch 23 Loss 1.0924

 Epoch 24 Batch 0 Loss 1.0290


52it [00:06,  8.12it/s]

 Epoch 24 Batch 50 Loss 1.0669


102it [00:12,  8.25it/s]

 Epoch 24 Batch 100 Loss 1.0100


152it [00:19,  8.02it/s]

 Epoch 24 Batch 150 Loss 1.0611


202it [00:25,  8.02it/s]

 Epoch 24 Batch 200 Loss 1.0721


252it [00:31,  8.14it/s]

 Epoch 24 Batch 250 Loss 1.0262


302it [00:37,  7.74it/s]

 Epoch 24 Batch 300 Loss 1.0871


312it [00:38,  8.01it/s]
1it [00:00,  7.98it/s]

 Epoch 24 Loss 1.0689

 Epoch 25 Batch 0 Loss 0.9967


52it [00:06,  8.01it/s]

 Epoch 25 Batch 50 Loss 1.0444


102it [00:12,  7.95it/s]

 Epoch 25 Batch 100 Loss 0.9829


152it [00:18,  8.03it/s]

 Epoch 25 Batch 150 Loss 1.0416


202it [00:25,  8.18it/s]

 Epoch 25 Batch 200 Loss 1.0393


252it [00:31,  7.84it/s]

 Epoch 25 Batch 250 Loss 1.0004


302it [00:37,  7.94it/s]

 Epoch 25 Batch 300 Loss 1.0552


312it [00:39,  7.99it/s]
1it [00:00,  7.95it/s]

 Epoch 25 Loss 1.0449

 Epoch 26 Batch 0 Loss 0.9709


52it [00:06,  8.08it/s]

 Epoch 26 Batch 50 Loss 1.0301


102it [00:12,  8.20it/s]

 Epoch 26 Batch 100 Loss 0.9525


152it [00:18,  8.09it/s]

 Epoch 26 Batch 150 Loss 1.0144


202it [00:25,  7.73it/s]

 Epoch 26 Batch 200 Loss 1.0158


252it [00:31,  8.02it/s]

 Epoch 26 Batch 250 Loss 0.9742


302it [00:37,  8.15it/s]

 Epoch 26 Batch 300 Loss 1.0370


312it [00:38,  8.04it/s]
1it [00:00,  8.33it/s]

 Epoch 26 Loss 1.0220

 Epoch 27 Batch 0 Loss 0.9540


52it [00:06,  8.17it/s]

 Epoch 27 Batch 50 Loss 1.0062


102it [00:12,  7.95it/s]

 Epoch 27 Batch 100 Loss 0.9395


152it [00:18,  8.15it/s]

 Epoch 27 Batch 150 Loss 0.9888


202it [00:25,  7.80it/s]

 Epoch 27 Batch 200 Loss 0.9984


252it [00:31,  7.82it/s]

 Epoch 27 Batch 250 Loss 0.9565


302it [00:37,  7.83it/s]

 Epoch 27 Batch 300 Loss 1.0029


312it [00:39,  8.00it/s]
1it [00:00,  8.01it/s]

 Epoch 27 Loss 1.0004

 Epoch 28 Batch 0 Loss 0.9223


52it [00:06,  7.74it/s]

 Epoch 28 Batch 50 Loss 0.9932


102it [00:12,  8.01it/s]

 Epoch 28 Batch 100 Loss 0.9139


152it [00:19,  8.18it/s]

 Epoch 28 Batch 150 Loss 0.9643


202it [00:25,  8.25it/s]

 Epoch 28 Batch 200 Loss 0.9871


252it [00:31,  7.88it/s]

 Epoch 28 Batch 250 Loss 0.9354


302it [00:37,  7.78it/s]

 Epoch 28 Batch 300 Loss 0.9785


312it [00:39,  7.97it/s]
1it [00:00,  8.02it/s]

 Epoch 28 Loss 0.9767

 Epoch 29 Batch 0 Loss 0.8963


52it [00:06,  7.89it/s]

 Epoch 29 Batch 50 Loss 0.9730


102it [00:12,  7.85it/s]

 Epoch 29 Batch 100 Loss 0.8896


152it [00:19,  7.94it/s]

 Epoch 29 Batch 150 Loss 0.9471


202it [00:25,  8.05it/s]

 Epoch 29 Batch 200 Loss 0.9658


252it [00:31,  8.13it/s]

 Epoch 29 Batch 250 Loss 0.9161


302it [00:37,  7.92it/s]

 Epoch 29 Batch 300 Loss 0.9587


312it [00:39,  7.97it/s]
1it [00:00,  7.99it/s]

 Epoch 29 Loss 0.9554

 Epoch 30 Batch 0 Loss 0.8742


52it [00:06,  8.09it/s]

 Epoch 30 Batch 50 Loss 0.9482


102it [00:12,  7.98it/s]

 Epoch 30 Batch 100 Loss 0.8749


152it [00:18,  7.99it/s]

 Epoch 30 Batch 150 Loss 0.9259


202it [00:25,  7.96it/s]

 Epoch 30 Batch 200 Loss 0.9438


252it [00:31,  8.03it/s]

 Epoch 30 Batch 250 Loss 0.8988


302it [00:37,  8.05it/s]

 Epoch 30 Batch 300 Loss 0.9431


312it [00:38,  8.02it/s]
1it [00:00,  8.22it/s]

 Epoch 30 Loss 0.9356

 Epoch 31 Batch 0 Loss 0.8569


52it [00:06,  8.01it/s]

 Epoch 31 Batch 50 Loss 0.9228


102it [00:12,  8.14it/s]

 Epoch 31 Batch 100 Loss 0.8569


152it [00:18,  7.85it/s]

 Epoch 31 Batch 150 Loss 0.9217


202it [00:25,  8.09it/s]

 Epoch 31 Batch 200 Loss 0.9215


252it [00:31,  7.86it/s]

 Epoch 31 Batch 250 Loss 0.8734


302it [00:37,  8.04it/s]

 Epoch 31 Batch 300 Loss 0.9197


312it [00:38,  8.01it/s]
1it [00:00,  8.20it/s]

 Epoch 31 Loss 0.9166

 Epoch 32 Batch 0 Loss 0.8356


52it [00:06,  7.99it/s]

 Epoch 32 Batch 50 Loss 0.9027


102it [00:12,  8.31it/s]

 Epoch 32 Batch 100 Loss 0.8476


152it [00:18,  8.10it/s]

 Epoch 32 Batch 150 Loss 0.9108


202it [00:25,  8.07it/s]

 Epoch 32 Batch 200 Loss 0.9074


252it [00:31,  7.96it/s]

 Epoch 32 Batch 250 Loss 0.8602


302it [00:37,  8.04it/s]

 Epoch 32 Batch 300 Loss 0.9015


312it [00:39,  7.98it/s]
1it [00:00,  8.09it/s]

 Epoch 32 Loss 0.8977

 Epoch 33 Batch 0 Loss 0.8218


52it [00:06,  8.07it/s]

 Epoch 33 Batch 50 Loss 0.8850


102it [00:12,  8.06it/s]

 Epoch 33 Batch 100 Loss 0.8220


152it [00:19,  7.86it/s]

 Epoch 33 Batch 150 Loss 0.8912


202it [00:25,  8.12it/s]

 Epoch 33 Batch 200 Loss 0.9050


252it [00:31,  7.86it/s]

 Epoch 33 Batch 250 Loss 0.8529


302it [00:37,  7.93it/s]

 Epoch 33 Batch 300 Loss 0.8808


312it [00:39,  8.00it/s]
1it [00:00,  7.62it/s]

 Epoch 33 Loss 0.8812

 Epoch 34 Batch 0 Loss 0.8031


52it [00:06,  7.82it/s]

 Epoch 34 Batch 50 Loss 0.8631


102it [00:12,  8.03it/s]

 Epoch 34 Batch 100 Loss 0.8011


152it [00:19,  7.95it/s]

 Epoch 34 Batch 150 Loss 0.8706


202it [00:25,  7.98it/s]

 Epoch 34 Batch 200 Loss 0.8952


252it [00:31,  7.96it/s]

 Epoch 34 Batch 250 Loss 0.8344


302it [00:37,  7.95it/s]

 Epoch 34 Batch 300 Loss 0.8624


312it [00:39,  7.98it/s]
1it [00:00,  8.19it/s]

 Epoch 34 Loss 0.8641

 Epoch 35 Batch 0 Loss 0.7796


52it [00:06,  7.76it/s]

 Epoch 35 Batch 50 Loss 0.8480


102it [00:12,  8.18it/s]

 Epoch 35 Batch 100 Loss 0.7764


152it [00:19,  7.94it/s]

 Epoch 35 Batch 150 Loss 0.8484


202it [00:25,  8.00it/s]

 Epoch 35 Batch 200 Loss 0.8750


252it [00:31,  7.47it/s]

 Epoch 35 Batch 250 Loss 0.8116


302it [00:37,  7.94it/s]

 Epoch 35 Batch 300 Loss 0.8438


312it [00:39,  7.97it/s]
1it [00:00,  7.65it/s]

 Epoch 35 Loss 0.8464

 Epoch 36 Batch 0 Loss 0.7605


52it [00:06,  8.20it/s]

 Epoch 36 Batch 50 Loss 0.8378


102it [00:12,  8.13it/s]

 Epoch 36 Batch 100 Loss 0.7629


152it [00:19,  8.27it/s]

 Epoch 36 Batch 150 Loss 0.8258


202it [00:25,  8.04it/s]

 Epoch 36 Batch 200 Loss 0.8480


252it [00:31,  7.75it/s]

 Epoch 36 Batch 250 Loss 0.7990


302it [00:38,  7.83it/s]

 Epoch 36 Batch 300 Loss 0.8285


312it [00:39,  7.94it/s]
1it [00:00,  8.02it/s]

 Epoch 36 Loss 0.8277

 Epoch 37 Batch 0 Loss 0.7436


52it [00:06,  8.12it/s]

 Epoch 37 Batch 50 Loss 0.8094


102it [00:12,  7.80it/s]

 Epoch 37 Batch 100 Loss 0.7361


152it [00:18,  8.02it/s]

 Epoch 37 Batch 150 Loss 0.8167


202it [00:25,  7.88it/s]

 Epoch 37 Batch 200 Loss 0.8331


252it [00:31,  8.17it/s]

 Epoch 37 Batch 250 Loss 0.7703


302it [00:37,  7.98it/s]

 Epoch 37 Batch 300 Loss 0.8242


312it [00:38,  8.06it/s]
1it [00:00,  7.81it/s]

 Epoch 37 Loss 0.8116

 Epoch 38 Batch 0 Loss 0.7324


52it [00:06,  8.13it/s]

 Epoch 38 Batch 50 Loss 0.7970


102it [00:12,  8.24it/s]

 Epoch 38 Batch 100 Loss 0.7213


152it [00:18,  8.24it/s]

 Epoch 38 Batch 150 Loss 0.7996


202it [00:24,  8.20it/s]

 Epoch 38 Batch 200 Loss 0.8293


252it [00:31,  8.29it/s]

 Epoch 38 Batch 250 Loss 0.7519


302it [00:37,  8.09it/s]

 Epoch 38 Batch 300 Loss 0.8032


312it [00:38,  8.10it/s]
1it [00:00,  7.03it/s]

 Epoch 38 Loss 0.7925

 Epoch 39 Batch 0 Loss 0.7173


52it [00:06,  7.96it/s]

 Epoch 39 Batch 50 Loss 0.7712


102it [00:12,  8.18it/s]

 Epoch 39 Batch 100 Loss 0.7005


152it [00:18,  8.10it/s]

 Epoch 39 Batch 150 Loss 0.7776


202it [00:24,  7.86it/s]

 Epoch 39 Batch 200 Loss 0.8220


252it [00:30,  8.11it/s]

 Epoch 39 Batch 250 Loss 0.7483


302it [00:37,  8.08it/s]

 Epoch 39 Batch 300 Loss 0.7880


312it [00:38,  8.13it/s]


 Epoch 39 Loss 0.7753



## Inference Experiments
We need an encoder and a decoder for inference. They use the weights saved above.

In [None]:
# load the trained model for inference

def load_models(vocab_size, emb_size, hidden_units):
    # batch size 1 and no dropout for inference
    encoder = Encoder(vocab_size, emb_size, hidden_units, 
                    1, 0)
    decoder = Decoder(vocab_size, emb_size, hidden_units, 
                    1, 0)
    # run one dummy step so that the variables exist before loading the weights
    enc_hidden = encoder.initial_hidden_state()
    place_holder = tf.convert_to_tensor([[1]])
    enc_output, enc_hidden = encoder(place_holder, enc_hidden, training=False)
    decoder(place_holder, enc_hidden, enc_output, training=False)
    encoder.load_weights('enc_weights.h5')
    decoder.load_weights('dec_weights.h5')
    return encoder, decoder

inf_enc, inf_dec = load_models(VOCAB_SIZE, EMBEDDING_SIZE, HIDDEN_UNITS)

### Repetition
It turns out that the decoder doesn't always produce an `END-OF-SEQ` token; instead, it sometimes repeats the same phrase until the length limit is reached.
The inference function therefore **falls back** to a list-based variant of [Floyd's algorithm](https://en.wikipedia.org/wiki/Cycle_detection) (which has $O(n)$ time complexity) to detect cycles in the output sequence. The check matters mainly when there is no `END-OF-SEQ`, although the code below applies it whenever `decycle=True`.
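
Why repetition shows up can be illustrated with a toy greedy decoder. In the sketch below, a lookup table stands in for the trained decoder, which is a strong simplification because the real decoder also carries a hidden state. The point is only that argmax decoding is deterministic: once it revisits a situation it has seen before, it keeps repeating until `max_len` is reached. All names and tokens are made up for the example.

```python
def greedy_decode(next_token, start, end, max_len):
    """Feed each prediction back as the next input until END or max_len."""
    out, tok = [], start
    for _ in range(max_len):
        tok = next_token[tok]
        if tok == end:
            break
        out.append(tok)
    return out

# Toy deterministic "models": lookup tables instead of a trained decoder
table_with_end = {'<s>': 'no', 'no': '.', '.': '</s>'}
table_looping = {'<s>': 'no', 'no': '.', '.': 'no'}   # never reaches </s>

print(greedy_decode(table_with_end, '<s>', '</s>', max_len=8))
print(greedy_decode(table_looping, '<s>', '</s>', max_len=8))
```

The first table stops after two tokens, while the second fills all eight slots with the same two-token loop.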

In [None]:
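def floyd(L):
    '''List-based variant of Floyd's cycle-finding algorithm.

    Returns a slice covering one period of the detected cycle (tokens before
    the cycle are dropped), or the whole list if no cycle is found.
    '''

    # first index v with L[v] == L[2v] (the tortoise-and-hare meeting point)
    for i in range(1, len(L)):
        if 2 * i < len(L) and L[i] == L[2 * i]:
            v = i; break
    else:
        return slice(0, len(L))

    # mu: first position where the sequence matches itself v steps later
    for i in range(len(L)):
        if L[i] == L[v + i]:
            mu = i; break
    
    # lam: distance from mu to the next occurrence of L[mu]
    for i in range(1, len(L) - mu):
        if L[mu] == L[mu + i]:
            lam = i; break

    return slice(mu, mu + lam)


def infer(sent, max_len=40, doclean=True, decycle=True):   
    inputs = texts2ids([sent], tokenizer, doclean=doclean)
    enc_hidden = inf_enc.initial_hidden_state()
    enc_output, enc_hidden = inf_enc(inputs, enc_hidden, training=False)

    dec_hidden = enc_hidden
    dec_input = tf.expand_dims(START_TOK, 1)
    result = []
    for i in range(max_len):
        pred, dec_hidden, _ = inf_dec(
            dec_input, dec_hidden, enc_output, training=False)
        # greedy decoding: pick the most likely token
        pred_id = tf.argmax(pred, axis=1).numpy()
        
        if pred_id[0] == END_TOK[0]:
            break
        else:
            result.append(pred_id)
            # feed the prediction back as the next decoder input
            dec_input = tf.expand_dims(pred_id, 1)
    
    if not result:
        # END-OF-SEQ was predicted first, so there is nothing to decode
        return ''

    ids = np.concatenate(result).tolist()
    if decycle: ids = ids[floyd(ids)]
    # drop the START/END IDs, which lie outside the tokenizer's vocabulary
    ids = [t for t in ids if t < tokenizer.vocab_size]
    return tokenizer.decode(ids)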
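
For comparison, repetition can also be trimmed while keeping the tokens that come before the cycle. The sketch below is a hypothetical alternative, not the method used in this notebook: it searches for the earliest position from which the rest of the sequence is periodic and keeps the prefix plus one copy of the repeated segment. It is brute force, cubic in the worst case, which is acceptable for sequences of at most 40 tokens.

```python
def trim_repetition(tokens, min_repeats=3):
    """Hypothetical alternative: keep the prefix plus one copy of a repeated tail.

    Finds the earliest start from which the remaining tokens are periodic with
    at least `min_repeats` full copies, optionally followed by a partial copy.
    """
    n = len(tokens)
    for start in range(n):
        max_period = (n - start) // min_repeats
        for period in range(1, max_period + 1):
            # periodic tail: every token equals the one `period` steps later
            if all(tokens[i] == tokens[i + period]
                   for i in range(start, n - period)):
                return tokens[:start + period]
    return tokens

looping = "i don t know . you don t know . you don t know . you don t".split()
print(' '.join(trim_repetition(looping)))

normal = "hi . what are you talkin about ?".split()
print(' '.join(trim_repetition(normal)))
```

Requiring at least three copies is a heuristic to avoid trimming short legitimate repetitions such as a doubled word at the end of a sentence.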

In [None]:
test_sents = ['Hi',
              'How are you?',
              "What's your name?",
              'Tell me about yourself',
              'Do you love me?',
              "What's the meaning of life?",
              'How is the weather today?',
              "Let's have a dinner",
              'Are you a bot?',
              'Are you a robot?',
              'Why not?',
              'Harry Potter?']

for s in test_sents:
    print('Q: %s\nA: %s\n' % (s, infer(s)))

Q: Hi
A: hi . what are you talkin about ?

Q: How are you?
A: i don t know what you re a little taste of you to get excited . 

Q: What's your name?
A: don t know . you 

Q: Tell me about yourself
A: i ve been trails .

Q: Do you love me?
A: no . . . 

Q: What's the meaning of life?
A: what the hell are you ?

Q: How is the weather today?
A: you re a little rules these men 

Q: Let's have a dinner
A: it s not anymore . 

Q: Are you a bot?
A: he s not going . 

Q: Are you a robot?
A: yeah . it was in the world .

Q: Why not?
A: i m not a little acting . 

Q: Harry Potter?
A: that ain t that he was in the world .



### Quality
* The responses are generally grammatical and mostly meaningful, with correct use of punctuation.
* "_are you a bot?_" got a weird answer, but "_are you a robot?_" got a better one.
* The chatbot is sometimes rude ("_what's the meaning of life?_" -> "_what the hell are you ?_").
* Here's what happens if we __don't__ manually detect cycles:

In [None]:
for s in (test_sents[2], test_sents[4]):
    print('Q: %s\nA: %s\n' % (s, infer(s, decycle=False)))

Q: What's your name?
A: i don t know . you don t know . you don t know . you don t know . you don t know . you don t know . you don t know . you don t know . 

Q: Do you love me?
A: no . . . no . . . no . . . no . . . no . . . no . . . no . . . no . . . no . . . no . . . no . . . no . . . no . . . no . . . no . . . no . . . no . . . no . . . no . . . no . . . 



### Comparison with baseline
ChatterBot is used as the baseline. Some of its answers are not as good as ours!

In [None]:
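def get_baseline():
    # this uses the pre-1.0 ChatterBot API; in ChatterBot 1.0+ training goes
    # through chatterbot.trainers.ChatterBotCorpusTrainer(bot).train(...)
    from chatterbot import ChatBot
    bot = ChatBot('Bot', trainer='chatterbot.trainers.ChatterBotCorpusTrainer')
    bot.train("chatterbot.corpus.english")
    
    return bot.get_response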

baseline_bot = get_baseline()

In [None]:
for s in test_sents:
    print('Q: %s\nA: %s\n' % (s, baseline_bot(s)))

Q: Hi
A: How are you doing?

Q: How are you?
A: Who? Who is but a form following the function of what

Q: What's your name?
A: I am still young

Q: Tell me about yourself
A: What do you want to know?

Q: Do you love me?
A: I could get mad.

Q: What's the meaning of life?
A: is a literary reference to the illuminatus.

Q: How is the weather today?
A: Maradona is great. Sinsemillia is even better.

Q: Let's have a dinner
A: What is your question?

Q: Are you a bot?
A: Yes I am.

Q: Are you a robot?
A: Yes I am.

Q: Why not?
A: i support the 2nd amendment.

Q: Harry Potter?
A: I am a programmer

