# Prediction Visualizer

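This notebook shows how to visualize the sentence- and word-level attention weights of individual documents in the test set, for debugging purposes. Each word is shaded by its word attention weight (`OrRd` colormap), and each sentence's background is shaded by its sentence attention weight (`binary` colormap).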

In [23]:
import html

import matplotlib
import matplotlib.pyplot as plt
import srsly
from IPython.display import HTML, display

In [5]:
    PREDICTIONS_PATH = "../logdir/ag_news_glove_run/bs=32,lr=7.4e-04,end_lr=0e0/ie_dirs/glove_run.jsonl"

In [6]:
# Materialize every prediction record so individual documents can be indexed by position.
predictions = [record for record in srsly.read_jsonl(PREDICTIONS_PATH)]

In [18]:
# Adapted from https://github.com/sharkmir1/Hierarchical-Attention-Network/blob/master/utils.py
def map_sentence_to_color(words, scores, sent_score):
    """Render one sentence as an HTML paragraph shaded by its attention weights.

    The sentence is wrapped in a span whose background colour comes from the
    ``binary`` colormap evaluated at ``sent_score``. Each word gets its own span
    whose background colour comes from the ``OrRd`` colormap evaluated at that
    word's score.

    :param words: iterable of word strings; each word is HTML-escaped before rendering
    :param scores: iterable of word attention scores, paired positionally with
        ``words`` (``zip`` stops at the shorter of the two)
    :param sent_score: sentence attention score
    :return: HTML-formatted string containing a single, properly closed ``<p>`` element
    """
    # plt.get_cmap works across matplotlib versions; matplotlib.cm.get_cmap was
    # deprecated in 3.7 and removed in 3.9.
    # NOTE(review): colormaps span their full range for floats in [0, 1];
    # attention weights are presumably already normalized to that range — confirm.
    sentence_cmap = plt.get_cmap("binary")
    word_cmap = plt.get_cmap("OrRd")

    sentence_color = matplotlib.colors.rgb2hex(sentence_cmap(sent_score)[:3])
    word_template = (
        '<span class="barcode" style="color: black; background-color: {}">'
        "&nbsp;{}&nbsp;</span>"
    )
    # Escape words so characters such as '<' or '&' in the document cannot
    # break (or inject into) the generated markup.
    word_spans = "".join(
        word_template.format(
            matplotlib.colors.rgb2hex(word_cmap(score)[:3]),
            html.escape(word),
        )
        for word, score in zip(words, scores)
    )
    opening = '<p><span style="margin:5px; padding:5px; background-color: {}">'.format(
        sentence_color
    )
    return opening + word_spans + "</span></p>"

In [27]:
def visualize(i):
    """Print the labels of prediction ``i`` and render its attention map as HTML.

    Reads the ``i``-th record of the module-level ``predictions`` list. The
    record must provide the keys ``original_document``, ``word_att_weights``,
    ``sentence_att_weights``, ``true_label`` and ``predicted_label``.
    Sentences, per-word weights and per-sentence weights are paired
    positionally; ``zip`` stops at the shortest of the three.

    :param i: index into ``predictions``
    """
    record = predictions[i]
    sentences = record["original_document"]
    word_weights = record["word_att_weights"]
    sentence_weights = record["sentence_att_weights"]

    for caption, key in (("True Label", "true_label"), ("Predicted Label", "predicted_label")):
        print(f"{caption}: {record[key]}")

    sentence_html = [
        map_sentence_to_color(sentence, weights, weight)
        for sentence, weights, weight in zip(sentences, word_weights, sentence_weights)
    ]
    display(HTML("<h2>Attention Visualization</h2>" + "".join(sentence_html)))

In [33]:
# Render the attention map for the test document at position 200.
visualize(i=200)

True Label: Sci/Tech
Predicted Label: Sci/Tech
