# Imports

In [None]:
import numpy as np
from unicodedata import normalize
from pprint import pprint
import string
import re
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook

# import plaidml.keras
# plaidml.keras.install_backend()
# import os
# os.environ["KERAS_BACKEND"] = "plaidml.keras.backend"
# import plaidml.keras.backend as K
import keras
import keras.backend as K
from keras.models import Sequential, Model, load_model
from keras.layers import Dense, LSTM, Concatenate, Input, Embedding, TimeDistributed, Flatten, Dropout, RepeatVector
from keras.callbacks import ModelCheckpoint
from keras.preprocessing.text import Tokenizer

# Reading movie lines

In [None]:
# Translation table that deletes every ASCII punctuation character.
table = str.maketrans('', '', string.punctuation)
# Matches any single character that is not in string.printable.
re_print = re.compile('[^%s]' % re.escape(string.printable))


def clean_sentence(line):
    """Normalise one raw utterance into lowercase, purely alphabetic words.

    Steps, in order: strip surrounding whitespace, drop ``--`` and double
    quotes, fold accented characters to plain ASCII, split on whitespace,
    lowercase, delete punctuation and non-printable characters, and finally
    discard every token that is not purely alphabetic (e.g. numbers, or
    tokens that became empty).

    Args:
        line: Raw text of a single utterance.

    Returns:
        The surviving words joined by single spaces; ``''`` if none survive.
    """
    line = line.strip().replace('--', '').replace("  ", " ").replace('"', "")
    # NFD splits an accented letter into base letter + combining mark, so the
    # ASCII encode with 'ignore' drops only the mark and keeps the base letter.
    line = normalize('NFD', line).encode('ascii', 'ignore').decode('ascii')

    words = []
    for word in line.split():
        # lowercase -> strip punctuation -> strip non-printable characters
        word = re_print.sub('', word.lower().translate(table))
        # isalpha() is False for '' and for anything containing a digit
        if word.isalpha():
            words.append(word)
    return ' '.join(words)

# movie_lines.txt: the fields of every row are separated by '+++$+++'.
# The first field is used as the line id and the last one as the utterance.
with open('./cornell-movie-dialogs-corpus/movie_lines.txt', 'r', errors='ignore') as f:
    split_rows = [row.strip().split('+++$+++') for row in f]

# line id -> cleaned utterance
lines = {
    fields[0].strip(): clean_sentence(fields[-1])
    for fields in split_rows
}

del split_rows  # the raw rows are no longer needed

with open('./cornell-movie-dialogs-corpus/movie_conversations.txt', 'r', errors='ignore') as f:
    conversations = [row.strip() for row in f]

# The last field of a conversation row is a list of line ids written out as
# text; deleting brackets, quotes and spaces leaves a comma-separated string
# that is split into the individual ids.
id_list_noise = str.maketrans('', '', "[]' ")
conversations = [
    row.split('+++$+++')[-1].strip().translate(id_list_noise).split(',')
    for row in conversations
]

# Peek at the first ten entries of each structure.
pprint(dict(list(lines.items())[:10]))
print()
pprint(conversations[:10])

# Sanity check: every conversation must reference more than one line.
assert all(len(conversation) > 1 for conversation in conversations)


# Map line IDs to their text

In [None]:
# Replace every line id by its cleaned text, preserving conversation order.
conversations_with_lines = [
    [lines[line_id] for line_id in conversation]
    for conversation in conversations
]

pprint(conversations_with_lines[100:110])

# Pair those things

In [None]:
def pair_it(my_list):
    """Return every pair of neighbouring items of ``my_list``.

    Element ``i`` is paired with element ``i + 1``, so a sequence of ``n``
    items yields ``n - 1`` two-element lists; fewer than two items yield
    an empty list.
    """
    last_start = len(my_list) - 1
    return [[my_list[pos], my_list[pos + 1]] for pos in range(last_start)]

# One list of [utterance, reply] pairs per conversation ...
paired_conversations_agg = list(map(pair_it, conversations_with_lines))

# ... flattened into a single array holding all pairs.
conversations_pairs = np.array(
    [pair for group in paired_conversations_agg for pair in group]
)

# Show the first ten pairs.
for i in range(10):
    pprint(conversations_pairs[i])

# Noise reduction

In [None]:
# --- Full dataset -----------------------------------------------------------
# Combined word count (question + answer) of every pair, computed once.
pair_lengths = [
    len(question.split(' ')) + len(answer.split(' '))
    for question, answer in conversations_pairs
]

hist, edges = np.histogram(pair_lengths, density=True, bins=100)
center = (edges[:-1] + edges[1:]) / 2
f, ax = plt.subplots(figsize=(8, 8))
plt.xlabel('all conversations', fontsize=14)
plt.bar(center, hist, align='center', width=(edges[1] - edges[0]) * .8)
plt.grid()
plt.show()

longest_index = int(np.argmax(pair_lengths))
longest_converastion = conversations_pairs[longest_index]
print("longest conversation: \n{}\n".format(longest_converastion))
print("longest conversation has {} words.".format(pair_lengths[longest_index]))

# --- Filtering --------------------------------------------------------------
# Keep a pair only when both of its sentences are shorter than the limit.
max_sentence_lenght = 10  # maximum allowed sentence length in words
clensed_conversations = np.array([
    pair
    for pair in conversations_pairs
    if max(len(pair[0].split(' ')), len(pair[1].split(' '))) < max_sentence_lenght
])
print("filetered {} conversations\n".format(len(conversations_pairs) - len(clensed_conversations)))

# --- Reduced dataset --------------------------------------------------------
reduced_lengths = [
    len(question.split(' ')) + len(answer.split(' '))
    for question, answer in clensed_conversations
]

hist, edges = np.histogram(reduced_lengths, density=True, bins=100)
center = (edges[:-1] + edges[1:]) / 2
f, ax = plt.subplots(figsize=(8, 8))
plt.bar(center, hist, align='center', width=(edges[1] - edges[0]) * .8)
plt.xlabel('reduced conversations', fontsize=14)
plt.grid()
plt.show()

longest_index = int(np.argmax(reduced_lengths))
longest_converastion = clensed_conversations[longest_index]
print("longest conversation in reduce dataset: \n{}\n".format(longest_converastion))
print("longest conversation in reduce dataset has {} words.".format(reduced_lengths[longest_index]))



# Tokenization and one-hot encoding

In [None]:
# Flat view over every sentence (questions and answers alike).
all_lines = clensed_conversations.reshape(-1)

tokenizer = Tokenizer()
tokenizer.fit_on_texts(all_lines)

# Let NumPy pick the string width from the data: a fixed "<U34" dtype would
# silently truncate longer words, which could then never be looked up again.
vocabulary = np.array(list(tokenizer.word_index), dtype=str)
# word -> position in `vocabulary`; replaces a full array scan per word.
word_to_index = {word: index for index, word in enumerate(tokenizer.word_index)}

max_sentence_lenght = max(len(line.split()) for line in all_lines)
print("first word: {}, last word: {}".format(vocabulary[0], vocabulary[-1]))
print('max sentence lenght: {} words'.format(max_sentence_lenght))
print('vocab_size: {} words'.format(len(vocabulary)))

# One-hot tensors indexed as (pair, word position, vocabulary position).
one_hot_shape = (len(clensed_conversations), max_sentence_lenght, len(vocabulary))
encoder_input_data = np.zeros(one_hot_shape, dtype='uint8')
decoder_input_data = np.zeros(one_hot_shape, dtype='uint8')
decoder_target_data = np.zeros(one_hot_shape, dtype='uint8')

with tqdm_notebook(total=len(clensed_conversations)) as pbar:
    for i, (left, right) in enumerate(clensed_conversations):
        for t, word in enumerate(left.split()):
            encoder_input_data[i, t, word_to_index[word]] = 1
        for t, word in enumerate(right.split()):
            column = word_to_index[word]
            decoder_input_data[i, t, column] = 1
            if t > 0:
                # The target is the decoder input shifted one step earlier:
                # at position t-1 the expected output is the word at t.
                decoder_target_data[i, t - 1, column] = 1
        pbar.update(1)

# Define Model
![sequential](./images/seq2seq.jpg)

In [None]:
# Build the training model: a two-layer LSTM encoder whose per-layer final
# states initialise the matching layers of a two-layer LSTM decoder, followed
# by a softmax over the vocabulary at every decoder time step.
K.clear_session()  # start from a fresh Keras backend session
latent_dim = 1024  # number of units in every LSTM layer

# Encoder input: sequences of vocabulary-sized vectors, any number of steps.
encoder_inputs = Input(shape=(None, len(vocabulary)))

# First encoder layer returns the full sequence (fed to the second layer)
# together with its final hidden/cell state.
e_outputs, h1, c1 = LSTM(latent_dim, return_state=True, return_sequences=True)(encoder_inputs) 
# Second encoder layer: only its final hidden/cell state is kept.
_, h2, c2 = LSTM(latent_dim, return_state=True)(e_outputs) 
encoder_states = [h1, c1, h2, c2]

# Decoder input: same per-step width as the encoder input.
decoder_inputs = Input(shape=(None, len(vocabulary)))

# Each decoder layer starts from the final state of the encoder layer at the
# same depth (layer 1 <- h1/c1, layer 2 <- h2/c2).
out_layer1 = LSTM(latent_dim, return_sequences=True, return_state=True)
d_outputs, dh1, dc1 = out_layer1(decoder_inputs,initial_state= [h1, c1])
out_layer2 = LSTM(latent_dim, return_sequences=True, return_state=True)
final, dh2, dc2 = out_layer2(d_outputs, initial_state= [h2, c2])
# Softmax over the vocabulary, applied to the second decoder layer's output.
decoder_dense = Dense(len(vocabulary), activation='softmax')
decoder_outputs = decoder_dense(final)

# Training model: (encoder input, decoder input) -> decoder word distribution.
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)

model.compile(optimizer="rmsprop", loss="categorical_crossentropy", metrics=["accuracy"])
model.summary()


# Training

In [None]:
# model.load_weights("mount-this/seq2seq-model.h5")

In [None]:
# Training hyper-parameters and checkpoint location.
filename = 'mount-this/seq2seq-model.h5'
epochs = 20
batch_size = 64

# Checkpoint callback writing to `filename`; with save_best_only=True it
# saves only when the monitored quantity improves.
checkpoint = ModelCheckpoint(filename, verbose=1, save_best_only=True)

# Train on encoder/decoder inputs against the shifted decoder targets,
# holding out 20% of the samples for validation.
model.fit(
    [encoder_input_data, decoder_input_data],
    decoder_target_data,
    validation_split=0.2,
    callbacks=[checkpoint],
    epochs=epochs,
    batch_size=batch_size,
)

# Model restore

In [None]:
# model = load_model("mount-this/seq2seq-model.h5")
# # model = load_model("mount-this/seq2seq-overfited-model.h5")

# encoder_inputs = model.input[0]   # input_1
# encoder_outputs, state_h_enc, state_c_enc = model.layers[2].output   # lstm_1
# encoder_states = [state_h_enc, state_c_enc]
# encoder_model = Model(encoder_inputs, encoder_states)

# decoder_inputs = model.input[1]   # input_2
# decoder_state_input_h = Input(shape=(latent_dim,), name='input_3')
# decoder_state_input_c = Input(shape=(latent_dim,), name='input_4')
# decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
# decoder_lstm = model.layers[3]
# decoder_outputs, state_h_dec, state_c_dec = decoder_lstm(
#     decoder_inputs, initial_state=decoder_states_inputs)
# decoder_states = [state_h_dec, state_c_dec]
# decoder_dense = model.layers[4]
# decoder_outputs = decoder_dense(decoder_outputs)
# decoder_model = Model(
#     [decoder_inputs] + decoder_states_inputs,
#     [decoder_outputs] + decoder_states)

# Inference

In [None]:
# One-hot encode the prompt, one row per word position.
to_infer = "hello"
infer_words = to_infer.split()
encoded_infer = np.zeros((max_sentence_lenght, len(vocabulary)), dtype='uint8')
# decoded_infer = np.zeros((max_sentence_lenght, len(vocabulary)), dtype='uint8')
for i, word in enumerate(infer_words):
    encoded_infer[i, np.where(vocabulary == word)[0][0]] = 1
# Add a leading axis of size 1 so the array forms a batch with one sample.
encoded_infer = np.expand_dims(encoded_infer, axis=0)

# NOTE(review): `encoder_model` is not assigned by any active code visible
# here (only inside the commented-out restore cell) - confirm it is defined
# before this cell runs, and that its output really is per-step word scores.
res = encoder_model.predict(encoded_infer)
# Keep only the rows whose highest score exceeds 0.2 and map each of them to
# the vocabulary entry at its arg-max position.
confident_rows = [word_indexes for word_indexes in res[0] if np.max(word_indexes) > .2]
sentence = " ".join(vocabulary[np.argmax(word_indexes)] for word_indexes in confident_rows)
sentence