In [1]:
from code.LSTM.LSTM_bidi import * 
from code.util.heatmap import html_heatmap
import pickle
import codecs
import numpy as np
from IPython.display import display, HTML

This notebook performs LRP and SA on an exemplary word sequence, and for a chosen relevance target class. 

# Define word sequence and relevance target class

As a target class, you can freely choose one of the following classes:   0=very negative, 1=negative, 2=neutral, 3=positive, 4=very positive

As an input word sequence, either select a sentence from the test set, or define your own sequence.

In [2]:
# def get_test_sentence(sent_idx):
#     """Returns a test set sentence and its label, sent_idx must be an integer in [1, 2210]"""
#     idx = 1
#     with codecs.open("./data/sequence_test.txt", 'r', encoding='utf8') as f:
#         for line in f:
#             line          = line.rstrip('\n')
#             line          = line.split('\t')
#             label         = int(line[0])-1         # true sentence class
#             words         = line[1].split(' | ')   # sentence words
#             if idx == sent_idx:
#                 return words, label
#             idx +=1

In [3]:
# Select test sentence number 291:
# words, target_class = get_test_sentence(291)
# print(words)
# print(target_class)

In [4]:
# Uncomment to define own sequence (only words contained in the vocabulary are supported):
# words        = ['i','hate','the','movie','though','the','plot','is','interesting','.']
# words        = ['this','movie','was','actually','neither','that','funny',',','nor','super','witty','.']
# target_class =  0             


# ['he', 'is', 'a', 'good', 'boy']
# ['pps', 'bez', 'at', 'jj', 'nn']
# word id : 10366        tag id : 147
# word id : 3561        tag id : 326
# word id : 13320        tag id : 301
# word id : 7127        tag id : 75
# word id : 7956        tag id : 308


# Perform LRP

In [5]:
eps                 = 0.001
bias_factor         = 0.0                                               # recommended value
 
net                 = LSTM_bidi()

words = ['he', 'is', 'a', 'good', 'boy']

#w_indices           = [net.voc.index(w) for w in words]                 # convert to word IDs
w_indices = [10366, 3561, 13320, 7127, 7956]

# word_position = 1  # word = he
# target_class = 147 # tag = pps

# word_position = 2 # word = is
# target_class = 326 # tag = bez

# word_position = 3 # word = a
# target_class = 301 # tag = at


# word_position = 4 # word = good
# target_class = 75 # tag = jj


word_position = 5 # word = boy
target_class = 308 # tag = nn


Rx, Rx_rev, R_rest  = net.lrp(w_indices, word_position, target_class, eps, bias_factor)# LRP through the net
R_words             = np.sum(Rx + Rx_rev, axis=1)                       # word relevances

scores              = net.s.copy()                                      # classification 

(343,)
(343,)
(1, 343)
(64, 343)
(64, 1)
[[1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 1. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.
  0. 1. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 1. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 1. 0. 0. 0. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0.
  0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.
  0. 0. 1. 0. 0. 1. 0. 1. 0. 1. 0. 0. 0. 0. 0. 1. 0. 1. 0. 0. 1. 0. 0. 

In [6]:
print ("prediction scores:        ",   scores)
print ("\nLRP target class:         ", target_class)
print("\n word:          ",words[word_position-1])
print ("\nLRP relevances:")
for idx, w in enumerate(words):
    print ("\t\t\t" + "{:8.2f}".format(R_words[idx]) + "\t" + w)
print ("\nLRP heatmap:")    
display(HTML(html_heatmap(words, R_words)))

prediction scores:         [array([-1.03138480e-01,  9.43769114e-02, -4.58944979e-01, -7.75402610e-01,
       -5.55988072e-01, -4.77012909e-01, -5.45808025e-01, -6.22646186e-01,
        2.72547859e-01, -1.96936252e-01,  8.91085819e-01,  4.07549514e-01,
        5.82683445e-01, -3.06583048e-01, -4.03596828e-01,  3.18838632e-01,
       -3.91735183e-01, -5.72699645e-01, -9.64054576e-02, -4.51539410e-01,
        3.29802027e-02, -4.12128307e-01,  1.50943286e-01,  2.11933593e-02,
       -3.38676668e-01, -2.84210909e-02,  2.93121721e-01, -7.23947730e-01,
       -9.74315517e-02, -3.97672242e-01, -7.74577247e-01, -3.70581865e-01,
       -2.06753446e-01,  2.72264419e-01, -3.25413508e-01, -3.76615827e-01,
        8.48676005e-02, -7.56467686e-01, -5.65967262e-01, -2.22721472e-01,
       -4.89893580e-01, -7.02874424e-01, -4.93342172e-01, -5.89525771e-01,
       -6.17537039e-01,  1.32928512e+00, -2.09462631e-01, -5.31194536e-01,
       -8.16769222e-02, -2.20573174e-01,  4.07922950e-01, -3.06962337e-0

In [7]:
# sanity check 
bias_factor        = 1.0                                    # value for sanity check
Rx, Rx_rev, R_rest = net.lrp(w_indices, word_position, target_class, eps, bias_factor)# net.lrp(w_indices, target_class, eps, bias_factor)
R_tot              = Rx.sum() + Rx_rev.sum() + R_rest.sum() # sum of all "input" relevances

print(R_tot); print(np.allclose(R_tot, net.s[target_class]))# check relevance conservation

(343,)
(343,)
(1, 343)
(64, 343)
(64, 1)
[[ 1. -1. -1. -1. -1. -1. -1. -1.  1. -1.  1. -1.  1. -1.  1. -1. -1. -1.
  -1. -1. -1. -1.  1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1.  1. -1. -1.
   1. -1. -1. -1. -1. -1. -1. -1. -1.  1. -1. -1. -1.  1.  1. -1. -1. -1.
  -1. -1. -1.  1. -1. -1.  1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1.
  -1.  1.  1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1.  1.  1.
  -1. -1.  1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1.
  -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1.
  -1. -1. -1. -1. -1. -1. -1.  1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1.
  -1.  1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1.  1. -1. -1. -1. -1. -1.
  -1. -1. -1. -1.  1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1.
  -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1.
  -1.  1. -1. -1. -1. -1.  1. -1. -1. -1. -1.  1. -1. -1. -1. -1. -1. -1.
  -1. -1. -1. -1.  1. -1. -1. -1. -1. -1. -1. -1. -1.  1. -1. -1. -1. -

IndexError: list index out of range

# Perform SA

In [15]:
words = ['he', 'is', 'a', 'good', 'boy']

#w_indices           = [net.voc.index(w) for w in words]                 # convert to word IDs
w_indices = [10366, 3561, 13320, 7127, 7956]

# word_position = 1  # word = he
# target_class = 147 # tag = pps

# word_position = 2 # word = is
# target_class = 326 # tag = bez

# word_position = 3 # word = a
# target_class = 301 # tag = at


# word_position = 4 # word = good
# target_class = 75 # tag = jj

# word_position = 5 # word = boy
# target_class = 308 # tag = nn

word_position = 5 # word = boy
target_class = 308 # tag = nn

net              = LSTM_bidi()

#w_indices        = [net.voc.index(w) for w in words]                 # convert to word IDs
Gx, Gx_rev       = net.backward(w_indices, word_position, target_class)             # SA through the net
G_words          = (np.linalg.norm(Gx + Gx_rev, ord=2, axis=1))**2   # word relevances

scores           = net.s.copy()                                      # classification 

In [16]:
print ("prediction scores:       ",   scores)
print ("\nSA target class:         ", target_class)
print ("\nSA relevances:")
for idx, w in enumerate(words):
    print ("\t\t\t" + "{:8.2f}".format(G_words[idx]) + "\t" + w)
print ("\nSA heatmap:")    
display(HTML(html_heatmap(words, G_words)))

prediction scores:        [array([-1.03138480e-01,  9.43769114e-02, -4.58944979e-01, -7.75402610e-01,
       -5.55988072e-01, -4.77012909e-01, -5.45808025e-01, -6.22646186e-01,
        2.72547859e-01, -1.96936252e-01,  8.91085819e-01,  4.07549514e-01,
        5.82683445e-01, -3.06583048e-01, -4.03596828e-01,  3.18838632e-01,
       -3.91735183e-01, -5.72699645e-01, -9.64054576e-02, -4.51539410e-01,
        3.29802027e-02, -4.12128307e-01,  1.50943286e-01,  2.11933593e-02,
       -3.38676668e-01, -2.84210909e-02,  2.93121721e-01, -7.23947730e-01,
       -9.74315517e-02, -3.97672242e-01, -7.74577247e-01, -3.70581865e-01,
       -2.06753446e-01,  2.72264419e-01, -3.25413508e-01, -3.76615827e-01,
        8.48676005e-02, -7.56467686e-01, -5.65967262e-01, -2.22721472e-01,
       -4.89893580e-01, -7.02874424e-01, -4.93342172e-01, -5.89525771e-01,
       -6.17537039e-01,  1.32928512e+00, -2.09462631e-01, -5.31194536e-01,
       -8.16769222e-02, -2.20573174e-01,  4.07922950e-01, -3.06962337e-01