# Imports

In [None]:
import os
import random
import re

import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
from tqdm import tqdm
from tqdm import tnrange
from tqdm import tqdm_notebook


# Model

# v1 - Current
* One LSTM layer
* Non-batched training (one word per training step)

## Dataset Initialization

In [None]:
# Dataset
# Pretrained TF Hub text-embedding module; it produces the 50-d word vectors
# fed to the LSTM.
embed = hub.Module("https://tfhub.dev/google/nnlm-en-dim50/1")

# Word <-> id lookup tables built from one vocabulary file, so both directions
# agree on ids. num_oov_buckets=0 means no extra ids are reserved for unknown
# words; they get the table's default value (-1 per the TF docs — confirm).
VOCAB_FILE = "vocabulary-shakespeare.txt"
string2idtable = tf.contrib.lookup.index_table_from_file(
    vocabulary_file=VOCAB_FILE,
    num_oov_buckets=0,
)
id2stringtabel = tf.contrib.lookup.index_to_string_table_from_file(
    vocabulary_file=VOCAB_FILE,
)
def _insertSpace(sentence):
    sentence = sentence.decode()
    sentence = sentence.lower()
    sentence = re.sub(r'([\W\d])', r' \1 ', sentence)
    return sentence
def _getLabel(sentence):
    """Split a normalized line into model inputs and one-hot targets.

    Tokens are obtained by whitespace-splitting with ``tf.string_split``.

    Returns ``(inputs, targets)``:
      * ``inputs``  - every token except the last (string tensor).
      * ``targets`` - one-hot encoding (depth 11405) of every token's vocabulary id.

    The graph cell later prepends a zero vector as x_0 to the embedded inputs.
    As a result, input row t lines up with target row t.
    """
    tokens = tf.string_split([sentence]).values
    # The final token is never used as an input; it is only predicted.
    inputs = tokens[:-1]
    targets = tf.one_hot(string2idtable.lookup(tokens), 11405)
    return inputs, targets
filenames = ["poems/shakespeare/sonnets.txt"]
# TODO Use skip and filter methods to preprocess data rather than manually do it
# TODO Use Dataset.map method to map '\n' to 'xxxnewlinexxx'
# Each text line is one example. It is normalized in Python (py_func), split
# into (inputs, one-hot targets), shuffled, and repeated indefinitely.
dataset = (
    tf.data.TextLineDataset(filenames)
    .map(lambda line: tf.py_func(_insertSpace, [line], tf.string))
    .map(_getLabel)
    .shuffle(buffer_size=10000)
    .repeat()
)
iterator = dataset.make_initializable_iterator()
next_poem = iterator.get_next()

## Model Initialization and Graph Building

In [None]:
# Output projection applied to the LSTM output:
# logits = output @ softmax_w + softmax_b, i.e. [1, 400] x [400, 11405] + [1, 11405].
# NOTE: tf.get_variable without reuse raises if these names already exist, so
# re-running this cell against the same default graph fails.
softmax_w = tf.get_variable(
    "softmax_w",
    shape=[400, 11405],
    initializer=tf.random_normal_initializer,
)
softmax_b = tf.get_variable(
    "softmax_b",
    shape=[1, 11405],
    initializer=tf.random_normal_initializer,
)

# Single LSTM layer with 400 units. state_is_tuple=False keeps the cell state as
# one tensor, so it can be fed through a single [1, state_size] placeholder.
# TODO Use tuple for state
lstm = tf.contrib.rnn.LSTMCell(
    400,
    state_is_tuple=False,
    initializer=tf.random_normal_initializer,
    reuse=tf.AUTO_REUSE,
    name="LSTM1",
)

# Model
# One example from the dataset iterator: `sentence` holds the input tokens (all
# but the last), `label` the one-hot targets for every token.
sentence, label = next_poem
# Embed each input token with the TF Hub module.
sentence = embed(sentence)
# Prepend an all-zero 50-d vector as x_0, so row t of `sentence` is the input
# whose target is row t of `label`. The training loop fetches both tensors and
# feeds them one row at a time through the `x` / `y` placeholders below.
sentence = tf.concat([tf.zeros([1, 50]), sentence], 0, name="Insert_X_0")


# Single-step LSTM graph: the previous state is fed in from Python on every
# sess.run, and the new state (`out_state`) is fetched back out to be fed next.
state = tf.placeholder(shape=[1, lstm.state_size], dtype=tf.float32, name="Previous_State")

# Current input word vector and its one-hot target (unbatched).
x = tf.placeholder(shape=[50], dtype=tf.float32, name="Input_Word")
y = tf.placeholder(shape=[11405], dtype=tf.int32, name="Target_Word")
# TODO How to Access timestep?
# Add a batch dimension of 1 for the RNN cell.
input_word = tf.reshape(x, [1, 50])
output, out_state = lstm(input_word, state)
# Project the LSTM output onto the vocabulary: logits is [1, 11405].
logits = tf.add(tf.matmul(output, softmax_w), softmax_b)
possibility = tf.nn.softmax(logits=logits, name="Possibilities")
# Greedy decoding: index of the most probable word, mapped back to its string.
word_index = tf.argmax(possibility, axis=-1, name="Predict_Word_index")
word = id2stringtabel.lookup(word_index)
# Wrapping `y` in a list gives it the same leading batch dimension as `logits`.
loss_op = tf.losses.softmax_cross_entropy(onehot_labels=[y], logits=logits)
tf.summary.scalar(name="loss", tensor=loss_op)
# Merges every summary registered in the graph so far (in this notebook: the loss).
merged = tf.summary.merge_all()
# loss_op = tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=logits, name="Loss")

#total_loss = tf.get_variable(shape=[], dtype=tf.float32, name="Total_Cost", initializer=tf.zeros_initializer)

# TODO Train Ops
# Adam with its default hyperparameters; each sess.run of train_op applies one
# gradient update to the trainable variables.
optimizer = tf.train.AdamOptimizer()
train_op = optimizer.minimize(loss_op)

# Long-lived session, intentionally kept open so later cells can reuse it.
# Initialize model variables, the vocabulary lookup tables and the dataset iterator.
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(tf.tables_initializer())
sess.run(iterator.initializer)

# Training log: TensorBoard event files (including the graph) go to LOG_DIR.
LOG_DIR = "tmp/log/"
writer = tf.summary.FileWriter(LOG_DIR, sess.graph)
global_step = 1

## Train Module

In [None]:
# Train on `total_poems` examples, one word per sess.run. The LSTM state is
# carried across the words of one example and reset to zeros for the next one.
total_poems = 1000
for _ in range(total_poems):
    wordsVec, targets = sess.run([sentence, label])
    # wordsVec[t] is the input whose target is targets[t] (wordsVec[0] is the zero x_0).
    # An example with no tokens (e.g. a blank line) has no targets, and
    # indexing targets[0] would raise IndexError, so skip it.
    if targets.shape[0] == 0:
        continue
    # TODO How to Access timestep?
    # Initial step: feed x_0 with a zero state to obtain the first LSTM state.
    # No gradient update is applied for this step.
    _, pre_state = sess.run([output, out_state], feed_dict={x: wordsVec[0], y: targets[0], state: np.zeros([1, lstm.state_size], dtype=float)})
    for i in range(1, wordsVec.shape[0]):
        # One training update on word i, continuing from the previous state.
        _, pre_state, loss, summary, out = sess.run([train_op, out_state, loss_op, merged, word], feed_dict={x: wordsVec[i], y: targets[i], state: pre_state})
        writer.add_summary(summary, global_step=global_step)
        tf.logging.log_every_n(tf.logging.INFO, "Loss: %s | local step: %s | global step: %s | Output: %s", 100, loss, i, global_step, out)
        global_step += 1
# Previously printed the undefined name `total_steps` (NameError).
print("Trainning poems: ", total_poems)

## Predict Module

In [None]:
def pretty(input_str):
    """Replace each newline placeholder token with a real line break.

    Args:
        input_str: text containing ``xxxnewlinexxx`` markers.

    Returns:
        The text with every marker replaced by ``"\\n"``.
    """
    newline_marker = re.compile("xxxnewlinexxx")
    return newline_marker.sub("\n", input_str)
# Seed generation with a random vocabulary word. Feeding the zero vector with a
# zero state mirrors the x_0 step used in training and yields the first state.
random_word = id2stringtabel.lookup(tf.constant([random.randint(0, 11404)], dtype=tf.int64))
word_predict, prev_state = sess.run([random_word, out_state], feed_dict={x: np.zeros([50]), state: np.zeros([1, lstm.state_size], dtype=float)})
print(word_predict)
poem = [word_predict[0].decode()]

# Build the word -> embedding op ONCE. Calling embed()/tf.reshape inside the
# loop would add new nodes to the graph for every generated word.
predict_word_in = tf.placeholder(dtype=tf.string, shape=[1], name="Predict_Word_In")
predict_word_vec = tf.reshape(embed(predict_word_in), shape=[50])

# Safety cap: nothing guarantees the model ever emits the end token.
max_poem_words = 1000
while word_predict[0].decode() != "xxxendxxx" and len(poem) < max_poem_words:
    word_vec = sess.run(predict_word_vec, feed_dict={predict_word_in: word_predict})
    word_predict, prev_state = sess.run([word, out_state], feed_dict={x: word_vec, state: prev_state})
    poem.append(word_predict[0].decode())

# Show the finished poem once instead of reprinting it after every word.
generated_poem = pretty(" ".join(poem))
print(generated_poem)

# Dump to generated dir (created if missing). Tokens can be punctuation, so
# characters that are unsafe in file names are replaced with "_".
filename = re.sub(r"[^\w.-]", "_", "-".join(poem[:5]))
os.makedirs("generated", exist_ok=True)
with open("generated/" + filename + ".txt", "w", encoding="utf-8") as f:
    f.write(generated_poem)

In [None]:
# Write graph to file
# Close the TensorBoard FileWriter created in the session cell. Closing flushes
# any pending events (summaries and the graph) to the "tmp/log/" directory.
writer.close()

## Save Model

In [None]:
# TODO 

# v2 - Multi-RNN with Batching
* Deep LSTM
    - Try different numbers of layers
    - Try different numbers of LSTM units
* Batched training
* Try other embedding modules
    - https://tfhub.dev/google/nnlm-en-dim50/1
    - https://tfhub.dev/google/nnlm-en-dim50-with-normalization/1
    - https://tfhub.dev/google/nnlm-en-dim128/1
    - https://tfhub.dev/google/nnlm-en-dim128-with-normalization/1
    - https://tfhub.dev/google/Wiki-words-250/1
    - https://tfhub.dev/google/Wiki-words-250-with-normalization/1
    - https://tfhub.dev/google/Wiki-words-500/1
    - https://tfhub.dev/google/Wiki-words-500-with-normalization/1