In [None]:
from __future__ import print_function
import numpy as np
import gensim
import string
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
from tensorflow.keras.callbacks import ModelCheckpoint

import logging
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)

In [None]:
# Path to the dataset
path = "/content/sherlock_book.txt"

# Data cleaning and preprocessing

# Maximum number of words kept from each line of the file
max_sentence_len = 40

# Read the corpus line by line. The encoding is pinned so the result does not
# depend on the locale of the machine running the notebook.
with open(path, encoding="utf-8") as file_:
  docs = file_.readlines()

# Translation table that deletes every ASCII punctuation character.
# Built once here instead of once per line inside the comprehension.
punctuation_table = str.maketrans('', '', string.punctuation)

# Preprocessing the data
# 1. Convert every line to lower case
# 2. Remove punctuation marks
# 3. Split on whitespace and keep only the first `max_sentence_len` words
#    (longer lines are truncated; the remainder is discarded, not split off)
# 4. Drop lines that have fewer than three words left
sentences = [
    doc.lower().translate(punctuation_table).split()[:max_sentence_len]
    for doc in docs
]
sentences = [item for item in sentences if len(item) > 2]
print('Num of sentences:', len(sentences))

In [None]:
# Fit a Word2Vec model on the tokenised sentences.
# NOTE: `size` and `iter` are the gensim 3.x keyword names (gensim 4 renamed
# them to `vector_size` and `epochs`); the rest of this notebook also relies
# on the 3.x API, so they are kept as-is.
word_model = gensim.models.Word2Vec(
    sentences,
    size=100,
    window=5,
    min_count=1,
    iter=100,
)

# The learned embedding matrix: one row per vocabulary word,
# one column per embedding dimension.
pretrained_weights = word_model.wv.vectors
vocab_size = pretrained_weights.shape[0]
emdedding_size = pretrained_weights.shape[1]

print('Result embedding shape:', pretrained_weights.shape)

In [None]:
# Persist the trained Word2Vec model to disk.
# `model_path` is a relative path, so the file lands in the current working directory.
model_path = "sherlock_model"
word_model.save(model_path)

In [None]:
def get_most_similar(word, topn):
  """Print the `topn` vocabulary words most similar to `word`.

  Args:
    word: The query word to look up in the trained Word2Vec model.
    topn: How many neighbours to request.

  Returns:
    None. The list returned by `most_similar` is printed, or a
    "not in vocabulary" message is printed when `word` is unknown.
  """
  try:
    # Query through `.wv`, like the other lookups in this notebook; calling
    # `most_similar` directly on the model is deprecated in gensim.
    print(word_model.wv.most_similar(word, topn=topn))
  except KeyError:
    # Only a missing key means "unknown word". Any other failure is left to
    # propagate instead of being mislabelled as a vocabulary miss.
    print("Word <",word,"> not in vocabulary")

In [None]:
def word2idx(word):
  """Return the index of `word` in the Word2Vec vocabulary.

  Args:
    word: The token to look up.

  Returns:
    The integer index stored for `word` in `word_model.wv.vocab`, or -1 if
    `word` is not present in the vocabulary.
  """
  try:
    return word_model.wv.vocab[word].index
  except KeyError:
    # Only a missing key means "unknown word"; other errors must surface.
    return -1

def idx2word(idx):
  """Return the vocabulary word stored at position `idx`.

  Args:
    idx: Index into `word_model.wv.index2word`.

  Returns:
    The word at `idx`, or an empty string when `idx` is out of bounds.
    Negative indices count as out of bounds, so the -1 sentinel returned by
    `word2idx` maps to "" instead of wrapping around to the last word.
  """
  vocabulary = word_model.wv.index2word
  try:
    in_range = 0 <= idx < len(vocabulary)
    return vocabulary[idx] if in_range else ""
  except TypeError:
    # `idx` is not usable as an integer index (e.g. None, str, float).
    return ""

In [None]:
# Build the LSTM training set.
# Each row of `train_x` holds the indices of all words of a sentence except
# the last one, written from column 0; unused trailing columns stay 0.
# `train_y` holds the index of the sentence's last word (the prediction target).
num_sentences = len(sentences)
train_x = np.zeros([num_sentences, max_sentence_len], dtype=np.int32)
train_y = np.zeros([num_sentences], dtype=np.int32)
for i, sentence in enumerate(sentences):
  *context, target = sentence
  train_x[i, :len(context)] = [word2idx(token) for token in context]
  train_y[i] = word2idx(target)

In [None]:
# Next-word model: embedding lookup (initialised from the Word2Vec vectors)
# -> single LSTM layer -> dense projection onto the vocabulary -> softmax.
lstm_model = tf.keras.models.Sequential([
    tf.keras.layers.Embedding(
        input_dim=vocab_size,
        output_dim=emdedding_size,
        weights=[pretrained_weights],
    ),
    tf.keras.layers.LSTM(units=emdedding_size),
    tf.keras.layers.Dense(units=vocab_size),
    tf.keras.layers.Activation('softmax'),
])

# Targets are integer word indices, hence the sparse cross-entropy loss.
lstm_model.compile(loss='sparse_categorical_crossentropy', optimizer='adam')

In [None]:
def sample(preds, temperature=1.0):
  """Draw a class index from `preds` using temperature sampling.

  Args:
    preds: 1-D sequence of non-negative class probabilities.
    temperature: Values below 1 sharpen the distribution, values above 1
      flatten it. A value <= 0 disables sampling and returns the plain argmax.

  Returns:
    The index of the chosen class.
  """
  if temperature <= 0:
    return np.argmax(preds)
  preds = np.asarray(preds).astype('float64')
  # log(0) yields -inf, which is the correct logit for an impossible class,
  # so the divide-by-zero warning carries no information here.
  with np.errstate(divide='ignore'):
    logits = np.log(preds) / temperature
  # Shift so the largest logit is 0 before exponentiating. Without this, a
  # small temperature makes every exp() underflow to 0 and the normalisation
  # below becomes 0/0 (NaN). The shift cancels out in the ratio.
  logits = logits - np.max(logits)
  exp_preds = np.exp(logits)
  probas = exp_preds / np.sum(exp_preds)
  # One multinomial trial gives a one-hot row; its argmax is the sampled index.
  draw = np.random.multinomial(1, probas, 1)
  return np.argmax(draw)

In [None]:
def generate_next(lstm_model, text, count=10):
  """Extend `text` with `count` words sampled from the LSTM model.

  The seed text is normalised the same way as the training corpus (lower
  case, punctuation removed). Words unknown to the Word2Vec vocabulary are
  dropped, because `word2idx` reports them as -1, which is not a valid
  embedding index.

  Each prediction feeds the model ONE sequence laid out exactly like a
  training row: the most recent words starting at column 0, zero-filled up to
  `max_sentence_len`. Passing the flat index list straight to `predict` would
  make Keras treat every word as its own one-word sequence, so only the last
  word would influence the prediction.

  Args:
    lstm_model: The trained Keras model to predict with.
    text: Seed text to continue.
    count: Number of words to generate.

  Returns:
    The known seed words followed by the generated words, joined by spaces.

  Raises:
    ValueError: If no word of `text` is in the vocabulary.
  """
  cleaned = text.lower().translate(str.maketrans('', '', string.punctuation))
  word_idxs = [word2idx(word) for word in cleaned.split()]
  word_idxs = [idx for idx in word_idxs if idx >= 0]
  if not word_idxs:
    raise ValueError("None of the words in the seed text are in the vocabulary")

  # Training rows hold at most `max_sentence_len - 1` context words.
  context_len = max_sentence_len - 1
  for _ in range(count):
    window = word_idxs[-context_len:]
    model_input = np.zeros([1, max_sentence_len], dtype=np.int32)
    model_input[0, :len(window)] = window
    prediction = lstm_model.predict(x=model_input)
    # Batch of one sequence -> row 0 is the next-word distribution.
    idx = sample(prediction[0], temperature=0.7)
    word_idxs.append(idx)
  return ' '.join(idx2word(idx) for idx in word_idxs)

In [None]:
# Checkpoint callback: writes the complete model (not only the weights) to a
# file named after the epoch number, regardless of whether the epoch improved.
filepath = "weights-improvement-{epoch:02d}-.hdf5"
checkpoint = ModelCheckpoint(
    filepath=filepath,
    verbose=1,
    mode='auto',
    save_best_only=False,
    save_weights_only=False,
)
callbacks_list = [checkpoint]

In [None]:
# Fit the LSTM on the prepared index matrix; the checkpoint callback is
# invoked during training.
lstm_model.fit(
    x=train_x,
    y=train_y,
    epochs=25,
    batch_size=128,
    verbose=1,
    callbacks=callbacks_list,
)