Permalink
Browse files

code release

  • Loading branch information...
ArrasL
ArrasL committed Jun 29, 2017
1 parent f99e629 commit ee2fed449f6da59fe148dfac6d09349ef6e76003
Showing with 39 additions and 21 deletions.
  1. +3 −4 code/LSTM/LRP_linear_layer.py
  2. +2 −4 code/LSTM/LSTM_bidi.py
  3. +6 −8 code/util/heatmap.py
  4. +1 −0 data/README
  5. +0 −1 model/README
  6. +27 −4 run_example.ipynb
@@ -4,10 +4,9 @@
@date: 21.06.2017
@version: 1.0
@copyright: Copyright (c) 2017, Leila Arras, Gregoire Montavon, Klaus-Robert Mueller, Wojciech Samek
@license : BSD-2-Clause
@license: BSD-2-Clause
'''
import numpy as np
from numpy import newaxis as na
@@ -19,10 +18,10 @@ def lrp_linear(hin, w, b, hout, Rout, bias_nb_units, eps, bias_factor, debug=Fal
- hin: forward pass input, of shape (D,)
- w: connection weights, of shape (D, M)
- b: biases, of shape (M,)
- hout: forward pass, of shape output (M,) (unequal to np.dot(w.T,hin)+b if more than one incoming layer!)
- hout: forward pass output, of shape (M,) (unequal to np.dot(w.T,hin)+b if more than one incoming layer!)
- Rout: relevance at layer output, of shape (M,)
- bias_nb_units: number of lower-layer units onto which the bias/stabilizer contribution is redistributed
- eps: stabilizer
- eps: stabilizer (small positive number)
- bias_factor: for global relevance conservation set to 1.0, otherwise 0.0 to ignore bias redistribution
Returns:
- Rin: relevance at layer input, of shape (D,)
View
@@ -4,10 +4,9 @@
@date: 21.06.2017
@version: 1.0
@copyright: Copyright (c) 2017, Leila Arras, Gregoire Montavon, Klaus-Robert Mueller, Wojciech Samek
@license : BSD-2-Clause
@license: BSD-2-Clause
'''
import numpy as np
import pickle
from numpy import newaxis as na
@@ -16,7 +15,6 @@
class LSTM_bidi:
def __init__(self, model_path='./model/'):
# vocabulary
@@ -48,7 +46,7 @@ def __init__(self, model_path='./model/'):
def set_input(self, w, delete_pos=None):
"""
Build the numerical input x/x_rev from word sequence w (+ initialize hidden layers h, c)
Build the numerical input x/x_rev from word sequence indices w (+ initialize hidden layers h, c)
Optionally delete words at positions delete_pos.
"""
T = len(w) # input word sequence length
View
@@ -4,17 +4,15 @@
@date: 21.06.2017
@version: 1.0
@copyright: Copyright (c) 2017, Leila Arras, Gregoire Montavon, Klaus-Robert Mueller, Wojciech Samek
@license : BSD-2-Clause
@license: BSD-2-Clause
'''
import matplotlib.pyplot as plt
def rescale_score_by_abs (score, max_score, min_score):
"""
rescale positive score to the range [0.5, 1.0], negative score to the range [0.0, 0.5],
using the extremal scores max_score and min_score for normalization
using the extremal scores max_score and min_score for normalization
"""
# CASE 1: positive AND negative scores occur --------------------
@@ -45,12 +43,12 @@ def rescale_score_by_abs (score, max_score, min_score):
return 0.0
else:
return 0.5 - 0.5*(score/min_score)
def getRGB (c_tuple):
return "#%02x%02x%02x"%(int(c_tuple[0]*255), int(c_tuple[1]*255), int(c_tuple[2]*255))
def span_word (word, score, colormap):
return "<span style=\"background-color:"+getRGB(colormap(score))+"\">"+word+"</span>"
View
@@ -1,3 +1,4 @@
- sequence_test.txt
contains the Stanford Sentiment Treebank test set sentences and labels (1=very negative, 2=negative, 3=neutral, 4=positive, 5=very positive) as raw text.
View
@@ -12,4 +12,3 @@ contains the word embeddings as numpy array
contains the bidirectional LSTM model weights as python dictionary of numpy arrays
View
@@ -27,7 +27,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Define word sequence and target class"
"# Define word sequence and relevance target class"
]
},
{
@@ -117,7 +117,7 @@
"outputs": [],
"source": [
"eps = 0.001\n",
"bias_factor = 0.0\n",
"bias_factor = 0.0 # recommended value\n",
" \n",
"net = LSTM_bidi()\n",
"\n",
@@ -177,6 +177,29 @@
"display(HTML(html_heatmap(words, R_words)))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.7314968678\n",
"True\n"
]
}
],
"source": [
"# sanity check \n",
"bias_factor = 1.0 # value for sanity check\n",
"Rx, Rx_rev, R_rest = net.lrp(w_indices, target_class, eps, bias_factor)\n",
"R_tot = Rx.sum() + Rx_rev.sum() + R_rest.sum() # sum of all \"input\" relevances\n",
"\n",
"print(R_tot); print(np.allclose(R_tot, net.s[target_class]))# check relevance conservation"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -186,7 +209,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"metadata": {
"collapsed": true
},
@@ -203,7 +226,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"metadata": {},
"outputs": [
{

0 comments on commit ee2fed4

Please sign in to comment.