Skip to content

Commit

Permalink
code release
Browse files Browse the repository at this point in the history
  • Loading branch information
ArrasL committed Jun 29, 2017
1 parent f99e629 commit ee2fed4
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 21 deletions.
7 changes: 3 additions & 4 deletions code/LSTM/LRP_linear_layer.py
Expand Up @@ -4,10 +4,9 @@
@date: 21.06.2017 @date: 21.06.2017
@version: 1.0 @version: 1.0
@copyright: Copyright (c) 2017, Leila Arras, Gregoire Montavon, Klaus-Robert Mueller, Wojciech Samek @copyright: Copyright (c) 2017, Leila Arras, Gregoire Montavon, Klaus-Robert Mueller, Wojciech Samek
@license : BSD-2-Clause @license: BSD-2-Clause
''' '''



import numpy as np import numpy as np
from numpy import newaxis as na from numpy import newaxis as na


Expand All @@ -19,10 +18,10 @@ def lrp_linear(hin, w, b, hout, Rout, bias_nb_units, eps, bias_factor, debug=Fal
- hin: forward pass input, of shape (D,) - hin: forward pass input, of shape (D,)
- w: connection weights, of shape (D, M) - w: connection weights, of shape (D, M)
- b: biases, of shape (M,) - b: biases, of shape (M,)
- hout: forward pass, of shape output (M,) (unequal to np.dot(w.T,hin)+b if more than one incoming layer!) - hout: forward pass output, of shape (M,) (unequal to np.dot(w.T,hin)+b if more than one incoming layer!)
- Rout: relevance at layer output, of shape (M,) - Rout: relevance at layer output, of shape (M,)
- bias_nb_units: number of lower-layer units onto which the bias/stabilizer contribution is redistributed - bias_nb_units: number of lower-layer units onto which the bias/stabilizer contribution is redistributed
- eps: stabilizer - eps: stabilizer (small positive number)
- bias_factor: for global relevance conservation set to 1.0, otherwise 0.0 to ignore bias redistribution - bias_factor: for global relevance conservation set to 1.0, otherwise 0.0 to ignore bias redistribution
Returns: Returns:
- Rin: relevance at layer input, of shape (D,) - Rin: relevance at layer input, of shape (D,)
Expand Down
6 changes: 2 additions & 4 deletions code/LSTM/LSTM_bidi.py
Expand Up @@ -4,10 +4,9 @@
@date: 21.06.2017 @date: 21.06.2017
@version: 1.0 @version: 1.0
@copyright: Copyright (c) 2017, Leila Arras, Gregoire Montavon, Klaus-Robert Mueller, Wojciech Samek @copyright: Copyright (c) 2017, Leila Arras, Gregoire Montavon, Klaus-Robert Mueller, Wojciech Samek
@license : BSD-2-Clause @license: BSD-2-Clause
''' '''



import numpy as np import numpy as np
import pickle import pickle
from numpy import newaxis as na from numpy import newaxis as na
Expand All @@ -16,7 +15,6 @@


class LSTM_bidi: class LSTM_bidi:



def __init__(self, model_path='./model/'): def __init__(self, model_path='./model/'):


# vocabulary # vocabulary
Expand Down Expand Up @@ -48,7 +46,7 @@ def __init__(self, model_path='./model/'):


def set_input(self, w, delete_pos=None): def set_input(self, w, delete_pos=None):
""" """
Build the numerical input x/x_rev from word sequence w (+ initialize hidden layers h, c) Build the numerical input x/x_rev from word sequence indices w (+ initialize hidden layers h, c)
Optionally delete words at positions delete_pos. Optionally delete words at positions delete_pos.
""" """
T = len(w) # input word sequence length T = len(w) # input word sequence length
Expand Down
14 changes: 6 additions & 8 deletions code/util/heatmap.py
Expand Up @@ -4,17 +4,15 @@
@date: 21.06.2017 @date: 21.06.2017
@version: 1.0 @version: 1.0
@copyright: Copyright (c) 2017, Leila Arras, Gregoire Montavon, Klaus-Robert Mueller, Wojciech Samek @copyright: Copyright (c) 2017, Leila Arras, Gregoire Montavon, Klaus-Robert Mueller, Wojciech Samek
@license : BSD-2-Clause @license: BSD-2-Clause
''' '''



import matplotlib.pyplot as plt import matplotlib.pyplot as plt



def rescale_score_by_abs (score, max_score, min_score): def rescale_score_by_abs (score, max_score, min_score):
""" """
rescale positive score to the range [0.5, 1.0], negative score to the range [0.0, 0.5], rescale positive score to the range [0.5, 1.0], negative score to the range [0.0, 0.5],
using the extremal scores max_score and min_score for normalization using the extremal scores max_score and min_score for normalization
""" """


# CASE 1: positive AND negative scores occur -------------------- # CASE 1: positive AND negative scores occur --------------------
Expand Down Expand Up @@ -45,12 +43,12 @@ def rescale_score_by_abs (score, max_score, min_score):
return 0.0 return 0.0
else: else:
return 0.5 - 0.5*(score/min_score) return 0.5 - 0.5*(score/min_score)


def getRGB (c_tuple): def getRGB (c_tuple):
return "#%02x%02x%02x"%(int(c_tuple[0]*255), int(c_tuple[1]*255), int(c_tuple[2]*255)) return "#%02x%02x%02x"%(int(c_tuple[0]*255), int(c_tuple[1]*255), int(c_tuple[2]*255))


def span_word (word, score, colormap): def span_word (word, score, colormap):
return "<span style=\"background-color:"+getRGB(colormap(score))+"\">"+word+"</span>" return "<span style=\"background-color:"+getRGB(colormap(score))+"\">"+word+"</span>"


Expand Down
1 change: 1 addition & 0 deletions data/README
@@ -1,3 +1,4 @@
- sequence_test.txt - sequence_test.txt


contains the Stanford Sentiment Treebank test set sentences and labels (1=very negative, 2=negative, 3=neutral, 4=positive, 5=very positive) as raw text. contains the Stanford Sentiment Treebank test set sentences and labels (1=very negative, 2=negative, 3=neutral, 4=positive, 5=very positive) as raw text.

1 change: 0 additions & 1 deletion model/README
Expand Up @@ -12,4 +12,3 @@ contains the word embeddings as numpy array


contains the bidirectional LSTM model weights as python dictionary of numpy arrays contains the bidirectional LSTM model weights as python dictionary of numpy arrays



31 changes: 27 additions & 4 deletions run_example.ipynb
Expand Up @@ -27,7 +27,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"# Define word sequence and target class" "# Define word sequence and relevance target class"
] ]
}, },
{ {
Expand Down Expand Up @@ -117,7 +117,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"eps = 0.001\n", "eps = 0.001\n",
"bias_factor = 0.0\n", "bias_factor = 0.0 # recommended value\n",
" \n", " \n",
"net = LSTM_bidi()\n", "net = LSTM_bidi()\n",
"\n", "\n",
Expand Down Expand Up @@ -177,6 +177,29 @@
"display(HTML(html_heatmap(words, R_words)))" "display(HTML(html_heatmap(words, R_words)))"
] ]
}, },
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.7314968678\n",
"True\n"
]
}
],
"source": [
"# sanity check \n",
"bias_factor = 1.0 # value for sanity check\n",
"Rx, Rx_rev, R_rest = net.lrp(w_indices, target_class, eps, bias_factor)\n",
"R_tot = Rx.sum() + Rx_rev.sum() + R_rest.sum() # sum of all \"input\" relevances\n",
"\n",
"print(R_tot); print(np.allclose(R_tot, net.s[target_class]))# check relevance conservation"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
Expand All @@ -186,7 +209,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": 8,
"metadata": { "metadata": {
"collapsed": true "collapsed": true
}, },
Expand All @@ -203,7 +226,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 9,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
Expand Down

0 comments on commit ee2fed4

Please sign in to comment.