# 5) RNN

## Model - RNN

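Reference: Zhou et al. (2016), *Attention-Based Bidirectional Long Short-Term Memory Networks for Relation Classification* — [http://www.aclweb.org/anthology/P16-2034](http://www.aclweb.org/anthology/P16-2034)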

Finally, we train a bidirectional RNN with attention. RNNs model sequences well, but they are very hard to interpret. Attention helps here: it adds some interpretability, because the attention coefficients indicate which parts of the input the model weighs most heavily.

In [None]:
# Number of token positions per utterance fed to the RNN: embed_and_pad
# truncates longer utterances and zero-pads shorter ones to this length.
max_seq_len = 32

def embed_and_pad(utterances, embeddings, max_len):
    """Embed space-tokenized utterances into a fixed-length, zero-padded array.

    Each utterance is split on single spaces and its first ``max_len`` tokens
    are looked up in ``embeddings``. Positions past the end of an utterance,
    and tokens missing from the embedding vocabulary, stay as zero vectors.

    Args:
        utterances: sequence of strings (a 1-D numpy array or a plain list).
        embeddings: word-vector lookup exposing ``vector_size``, ``vocab``
            (used for membership tests) and ``get_vector(token)``.
        max_len: number of token positions per utterance in the output;
            longer utterances are truncated.

    Returns:
        Float ndarray of shape ``(len(utterances), max_len, embeddings.vector_size)``.
    """
    # len() works for numpy arrays and plain Python sequences alike.
    embedded = np.zeros((len(utterances), max_len, embeddings.vector_size))

    for i, utterance in enumerate(utterances):
        # The array is already zero-initialised, so padding and OOV positions
        # need no explicit write; only in-vocabulary tokens are filled in.
        for j, token in enumerate(utterance.split(" ")[:max_len]):
            if token in embeddings.vocab:
                embedded[i, j, :] = embeddings.get_vector(token)
    return embedded

# Embed both splits with the same fixed sequence length so train and test
# tensors share a shape.
train_x_rnn = embed_and_pad(train_x_cnn, embeddings, max_len=max_seq_len)
test_x_rnn = embed_and_pad(test_x_cnn, embeddings, max_len=max_seq_len)

In [None]:
# Reload the project's `rnn` module so edits to it take effect without a kernel restart.
reload(rnn)
# Clear the default TensorFlow graph so re-running this cell does not add a
# second copy of the model's ops to the old graph.
tf.reset_default_graph()

# Build the attention RNN over inputs padded/truncated to `max_seq_len` tokens.
# NOTE(review): presumably one output class per major character — confirm in rnn.RNN.
rnn_model = rnn.RNN(embeddings, num_major_characters, hidden_size=64, max_len=max_seq_len)

In [None]:
import os

display_step = 20
num_epochs = 10
batch_size = 128
# keep_prob fed on training steps; evaluation runs below always feed 1.0.
# NOTE(review): keep_prob presumably controls dropout inside rnn.RNN, in which
# case 1 disables it during training as well — confirm this is intended.
train_keep_prob = 1

train_writer = tf.summary.FileWriter("./rnn-board/train")
test_writer = tf.summary.FileWriter("./rnn-board/test")
saver = tf.train.Saver()

try:
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        train_writer.add_graph(sess.graph)

        global_step = 0
        X = rnn_model.X
        Y = rnn_model.Y
        keep_prob = rnn_model.keep_prob

        for epoch in range(num_epochs):
            print("{} epoch number: {}".format(datetime.now(), epoch + 1))
            # Shuffle into per-epoch copies instead of rebinding the notebook-level
            # train_x_rnn / train_y_cnn: rebinding train_y_cnn would silently
            # misalign it with train_x_cnn (which is not shuffled here), and
            # re-running the embedding cell would then pair inputs with wrong labels.
            epoch_x, epoch_y = shuffle(train_x_rnn, train_y_cnn)

            for step, x_batch, y_batch in batch_iter(epoch_x, epoch_y, batch_size):
                global_step += 1
                sess.run(
                    rnn_model.train_op,
                    feed_dict={X: x_batch, Y: y_batch, keep_prob: train_keep_prob},
                )

                # every so often, report the progress of our loss and training accuracy
                if step % display_step == 0:
                    # keep_prob is fed explicitly so the summary never depends on
                    # the placeholder's default value.
                    summ = sess.run(
                        rnn_model.merged_summary,
                        feed_dict={X: x_batch, Y: y_batch, keep_prob: 1.0},
                    )
                    train_writer.add_summary(summ, global_step=global_step)
                    train_writer.flush()

            # Evaluate on the whole test set once per epoch.
            test_acc, summ = sess.run(
                [rnn_model.accuracy, rnn_model.merged_summary],
                feed_dict={X: test_x_rnn, Y: test_y_cnn, keep_prob: 1.0},
            )
            test_writer.add_summary(summ, global_step=global_step)
            test_writer.flush()
            print("test accuracy = {:.4f}".format(test_acc))

        # save the model to disk so we can load it up later for use by `./eval.py`.
        # tf.train.Saver.save fails when the parent directory does not exist.
        os.makedirs("./rnn-ckpt", exist_ok=True)
        saver.save(sess, "./rnn-ckpt/model.ckpt")
finally:
    # Release the event-file handles even if training is interrupted.
    train_writer.close()
    test_writer.close()

### RNN Evaluation

In [None]:
saver = tf.train.Saver()

with tf.Session() as sess:
    saver.restore(sess, "./rnn-ckpt/model.ckpt")
    # Feed keep_prob explicitly so inference is never affected by the
    # placeholder's default value.
    rnn_test_attn, rnn_test_predictions = sess.run(
        (rnn_model.alpha, rnn_model.predictions),
        feed_dict={rnn_model.X: test_x_rnn, rnn_model.keep_prob: 1.0},
    )

# Per-class precision/recall report and confusion matrix on the test split.
print(classification_report(test_y_cnn, rnn_test_predictions, target_names=major_characters))

conf_matrix_plot(confusion_matrix(test_y_cnn, rnn_test_predictions), major_characters)

In [None]:
def visualize_attention(sentence, attn, max_len=None):
    """Display ``sentence`` with each word highlighted by its attention weight.

    Args:
        sentence: list of word strings.
        attn: per-position attention weights. They are rescaled so the largest
            weight is fully opaque; words beyond the end of ``attn`` get zero
            opacity.
        max_len: if given, a red ``|`` is inserted after word ``max_len - 1``,
            e.g. to mark where the model's input sequence was truncated.
    """
    import html  # local import keeps this cell self-contained

    attn = np.asarray(attn, dtype=float)
    peak = attn.max() if attn.size else 0.0
    # Avoid dividing by zero (all-zero or empty weights), which would produce
    # NaN opacities in the generated CSS.
    if peak > 0:
        attn = attn / peak

    tagged_sentence = []
    for i, word in enumerate(sentence):
        opacity = attn[i] if i < len(attn) else 0
        # Escape the word: it can come from free-form user input and must not
        # be interpreted as HTML markup.
        tagged_sentence.append(
            f'<span style="background-color: rgba(100, 200, 255, {opacity:.3f});">'
            f'{html.escape(word)}</span>'
        )
        if max_len and i == max_len - 1:
            tagged_sentence.append('<span style="color: red;">|</span>')
    display(Markdown(" ".join(tagged_sentence)))

In [None]:
user_input_to_classify = "i love you chandler do you want to go to the restaurant"

# Wrap the single utterance in an array so it goes through the same
# embedding/padding path as the test set.
rnn_user_input_x = embed_and_pad(np.array([user_input_to_classify]), embeddings, max_seq_len)

with tf.Session() as sess:
    saver.restore(sess, "./rnn-ckpt/model.ckpt")

    # Feed keep_prob explicitly so inference is never affected by the
    # placeholder's default value.
    rnn_user_attention, (rnn_user_input_pred,) = sess.run(
        (rnn_model.alpha, rnn_model.predictions),
        feed_dict={rnn_model.X: rnn_user_input_x, rnn_model.keep_prob: 1.0},
    )

print(major_characters[rnn_user_input_pred])
# Tokens past max_seq_len were truncated by embed_and_pad; the red bar drawn
# by visualize_attention marks that cut-off.
visualize_attention(user_input_to_classify.split(" "), rnn_user_attention[0], max_len=max_seq_len)

In [None]:
# NOTE(review): these three lists are never filled in this cell — confirm
# whether a later cell populates or uses them.
utterances_to_plot = []
attentions_to_plot = []
original_indices = []

# For every major character, show attention heatmaps for up to the first 200
# test utterances that were both labelled and predicted as that character.
for i in range(num_major_characters):
    display(Markdown("## " + major_characters[i]))
    correct_mask_i = np.logical_and(test_y_cnn == i, rnn_test_predictions == i)
    correct_predictions_i = np.where(correct_mask_i)
    for j in correct_predictions_i[0][:200]:
        sentence = test_x_cnn[j].split(" ")
        # Skip utterances consisting of a single token.
        if len(sentence) <= 1:
            continue
        visualize_attention(sentence, rnn_test_attn[j], max_len=max_seq_len)