# Imports

In [None]:
import numpy as np
from unicodedata import normalize
from pprint import pprint
import string
import re
import matplotlib.pyplot as plt
from keras.backend import clear_session
from keras.models import Sequential, Model
from keras.layers import Dense, LSTM, CuDNNLSTM, Input, Embedding, TimeDistributed, Flatten, Dropout, RepeatVector
from keras.callbacks import ModelCheckpoint
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from keras.utils import to_categorical
from keras.models import load_model

# Reading movie lines

In [None]:
# Translation table that deletes every ASCII punctuation character.
table = str.maketrans('', '', string.punctuation)
# Regex matching any character that is NOT in string.printable.
re_print = re.compile('[^%s]' % re.escape(string.printable))


def clean_sentence(line):
    """Normalize one raw dialogue line into lowercase, space-separated words.

    Steps: drop ``--`` and double quotes, fold accented characters to plain
    ASCII (characters with no ASCII base are dropped), lowercase each token,
    strip punctuation and non-printable characters, and keep only tokens that
    are purely alphabetic (tokens with digits, or left empty, are discarded).

    Args:
        line: raw text of a single line.

    Returns:
        The surviving words joined by single spaces ('' if none survive).
    """
    line = line.strip().replace('--', '').replace('"', "")
    # NFD splits accented letters into base letter + combining mark; encoding
    # to ASCII with 'ignore' then drops the marks (and any other non-ASCII).
    line = normalize('NFD', line).encode('ascii', 'ignore').decode('UTF-8')
    words = []
    # split() with no argument also collapses runs of whitespace.
    for word in line.split():
        word = re_print.sub('', word.lower().translate(table))
        if word.isalpha():
            words.append(word)
    return ' '.join(words)

# Field separator used in the Cornell movie-dialogs corpus files.
FIELD_SEPARATOR = '+++$+++'

# Map each line ID (first field) to its cleaned text (last field).
# Stream the file instead of materializing readlines(), and split each row once.
with open('./cornell-movie-dialogs-corpus/movie_lines.txt', 'r', errors='ignore') as f:
    lines = {}
    for row in f:
        fields = row.strip().split(FIELD_SEPARATOR)
        lines[fields[0].strip()] = clean_sentence(fields[-1])  # clean sentences

# Keep only the last field of each conversation row and turn it into a list of
# line IDs. The field presumably looks like a Python list literal, e.g.
# "['L194', 'L195']" (inferred from the characters stripped below).
with open('./cornell-movie-dialogs-corpus/movie_conversations.txt', 'r', errors='ignore') as f:
    conversations = [
        row.strip().split(FIELD_SEPARATOR)[-1].strip()
        .replace('[', '').replace(']', '').replace("'", '').replace(" ", '')
        .split(',')
        for row in f
    ]

pprint({k: lines[k] for k in list(lines)[:10]})
print()
pprint(conversations[:10])

# Every conversation must contain at least two lines to yield a pair.
assert all(len(conversation) > 1 for conversation in conversations), \
    'found conversations with fewer than two lines'


# Map line IDs to their text

In [None]:
# Resolve every line ID in each conversation to its cleaned text.
conversations_with_lines = [
    [lines[line_id] for line_id in conversation]
    for conversation in conversations
]

pprint(conversations_with_lines[100:110])

# Build consecutive (utterance, reply) pairs

In [None]:
def pair_it(my_list):
    """Pair each element of a sequence with its successor.

    ``pair_it([a, b, c])`` returns ``[[a, b], [b, c]]``. Sequences with fewer
    than two elements yield ``[]``.

    Args:
        my_list: a sliceable sequence, e.g. the lines of one conversation.

    Returns:
        A list of two-element lists ``[my_list[i], my_list[i + 1]]``.
    """
    return [[current, following] for current, following in zip(my_list, my_list[1:])]

# Pairs of consecutive lines, grouped per conversation.
paired_conversations_agg = [
    pair_it(conversation) for conversation in conversations_with_lines
]
# Flatten into one array with a row [utterance, reply] per pair.
conversations_pairs = np.array([pair for pairs in paired_conversations_agg for pair in pairs])
# Slice rather than index range(10) so this also works with fewer than 10 pairs.
for pair in conversations_pairs[:10]:
    pprint(pair)

# Noise reduction

In [None]:
# Word count (utterance + reply) of every pair, computed once and reused below.
# Counting uses split(' '), so an empty string still counts as one word.
pair_lengths = np.array([
    len(question.split(' ')) + len(answer.split(' '))
    for question, answer in conversations_pairs
])

# Normalized histogram of pair lengths, used to choose a length cut-off.
hist, edges = np.histogram(pair_lengths, density=True, bins=100)
center = (edges[:-1] + edges[1:]) / 2
f, ax = plt.subplots(figsize=(10, 10))
ax.bar(center, hist, align='center', width=(edges[1] - edges[0]) * .8)
ax.grid()
plt.show()

longest_converastion = conversations_pairs[pair_lengths.argmax()]
print("longest conversation: \n{}".format(longest_converastion))
print("\nlongest conversation has {} words".format(pair_lengths.max()))

# Keep only pairs with strictly fewer words than this (utterance + reply).
max_conversation_lenght = 38
clensed_conversations = conversations_pairs[pair_lengths < max_conversation_lenght]
print("filetered {} conversations".format(len(conversations_pairs) - len(clensed_conversations)))

# Build the vocabulary and encode sentences as arrays

In [None]:
def _encode_as_indices(sentences, word_to_index, length, dtype, label):
    """Encode sentences as a zero-padded (n_sentences, length) array of word indices.

    Prints "<label> steps <n>" and then progress every 10000 sentences.
    """
    encoded = np.zeros((len(sentences), length), dtype=dtype)
    print("{} steps {}".format(label, len(encoded)))
    for i, sentence in enumerate(sentences):
        if i % 10000 == 0:
            print(i)
        for j, word in enumerate(sentence.split()):
            encoded[i, j] = word_to_index[word]
    return encoded


def _encode_as_one_hot(sentences, word_to_index, length, depth, label):
    """Encode sentences as a (n_sentences, length, depth) uint8 one-hot array.

    Padding positions stay all-zero. Prints progress like _encode_as_indices.
    """
    encoded = np.zeros((len(sentences), length, depth), dtype='uint8')
    print("{} steps {}".format(label, len(encoded)))
    for i, sentence in enumerate(sentences):
        if i % 10000 == 0:
            print(i)
        for j, word in enumerate(sentence.split()):
            encoded[i, j, word_to_index[word]] = 1
    return encoded


# Every utterance and reply of the kept pairs, flattened into one sequence.
all_lines = clensed_conversations.reshape(-1)
sorted_words = sorted({word for line in all_lines for word in line.split()})
vocabulary = np.array(sorted_words)
# Index 0 is reserved for padding: the model's Embedding uses mask_zero=True and
# both it and the output Dense are sized len(vocabulary) + 1, so real words
# are numbered from 1. (Previously vocab_size was undefined and the first word
# was encoded as 0, i.e. masked out as padding.)
vocab_size = len(vocabulary) + 1
# O(1) dict lookup instead of an O(V) np.where scan per word.
word_to_index = {word: index for index, word in enumerate(sorted_words, start=1)}
max_sentence_lenght = max((len(line.split()) for line in all_lines), default=0)
print("first word: {}, last word: {}".format(vocabulary[0], vocabulary[-1]))
print('max sentence lenght: {} words'.format(max_sentence_lenght))
print('vocab_size: {} words'.format(len(vocabulary)))

# Indices go up to vocab_size - 1; widen the dtype if uint16 cannot hold them.
index_dtype = 'uint16' if vocab_size <= 2 ** 16 else 'uint32'

# First 80% of pairs for training, the rest for validation.
split_index = int(len(clensed_conversations) * .8)
train_pairs = clensed_conversations[:split_index]
test_pairs = clensed_conversations[split_index:]

# NOTE(review): the one-hot targets need about
# n_pairs * max_sentence_lenght * vocab_size bytes, which can be very large.
# Consider integer targets with sparse_categorical_crossentropy instead.
trainX = _encode_as_indices(train_pairs[:, 0], word_to_index, max_sentence_lenght, index_dtype, "trainX")
trainY = _encode_as_one_hot(train_pairs[:, 1], word_to_index, max_sentence_lenght, vocab_size, "trainY")
testX = _encode_as_indices(test_pairs[:, 0], word_to_index, max_sentence_lenght, index_dtype, "testX")
testY = _encode_as_one_hot(test_pairs[:, 1], word_to_index, max_sentence_lenght, vocab_size, "testY")

# Define Model

In [None]:
# Start from a fresh Keras graph so re-running this cell does not stack models.
clear_session()
n_units = 256
# One output slot per vocabulary word plus index 0, which mask_zero=True
# reserves for padding.
output_size = len(vocabulary) + 1

# Encoder LSTM compresses the input to one vector. RepeatVector feeds that
# vector to every decoder step, and a per-timestep softmax predicts a word.
model = Sequential([
    Embedding(output_size, n_units, input_length=max_sentence_lenght, mask_zero=True),
    LSTM(n_units),  # alternative: CuDNNLSTM (GPU-only)
    RepeatVector(max_sentence_lenght),
    LSTM(n_units, return_sequences=True),  # alternative: CuDNNLSTM (GPU-only)
    TimeDistributed(Dense(output_size, activation='softmax')),
])

model.compile(optimizer='adam', loss='categorical_crossentropy')
model.summary()

# Training

In [None]:
# fit model
filename = 'mount-this/model.h5'
checkpoint = ModelCheckpoint(filename, monitor='val_loss', verbose=1, save_best_only=True, mode='min')
model.fit(trainX, trainY, epochs=30, batch_size=64, validation_data=(testX, testY), callbacks=[checkpoint], verbose=1)

In [None]:
# model = load_model('chatbot-seq-400.h5')

In [None]:
# to_infer = 'how are you'
# source = np.zeros(max_sentence_lenght)

# for i in range(len(to_infer)):
#     source[i] = token_index[to_infer[i]]
# source = source.reshape((1, source.shape[0]))
# res = model.predict(source)
# sentence = "".join([input_characters[np.argmax(char_indxes)] for char_indxes in res[0]])
# print(to_infer + " -> " + sentence.strip())