
How to propagate individual hidden layer relevance scores of attention through LSTM? #8

Closed
sharan21 opened this issue Dec 4, 2019 · 1 comment

Comments

sharan21 commented Dec 4, 2019

My model consists of an encoder LSTM, an attention layer, and a linear decoder layer for the task of binary classification. So far I have propagated LRP all the way to the hidden state inputs of the attention layer, and I am not sure how to propagate each hidden state's relevance back to the input layer through the encoder LSTM.

If I understand correctly, this repo assumes that the model is a simple encoder LSTM plus a linear decoder which takes the final hidden state as input to produce the output class.

How can I propagate these individual hidden state scores through the LSTM using this approach? If I only propagate the last hidden state's scores through the LSTM using this code, it 1) doesn't take the other hidden state scores into account, and 2) assumes that the attention layer only takes the last hidden state as input.

I understand that this may be an open question; any help/advice on how to proceed would be greatly appreciated.

ArrasL (Owner) commented Mar 27, 2020

Hi Sharan,

sorry for the late answer! (I will try to be thorough to compensate ;-))

You can indeed backward-propagate LRP relevances through attention layers by using the same strategy we introduced in our original paper (namely the signal-take-all strategy) for the product between attention weights and hidden states.

Let me give you some general hints on how to proceed. First you need to understand that all operations/layers present in most recurrent neural networks essentially boil down to three basic operations:

  • linear layers (this includes fully-connected layers, and element-wise summation)
  • product between two neurons (where one neuron is a signed "signal" neuron, and the other is a "gate" or "attention weight" whose value is in the range [0, 1])
  • element-wise non-linear activation (tanh, sigmoid or softmax: tanh is used for "signal" neurons, while sigmoid and softmax are used for "gates" or "attention weights" to control/modulate the flow of information).

In the LRP backward pass, each of these layers can be handled in the following way:

  • on linear layers the LRP eps-rule can be applied (to redistribute relevance in proportion to forward-pass contributions)
  • for products the signal-take-all redistribution rule can be employed: it redistributes the relevance assigned to the product by higher-layer neurons entirely onto the "signal" neuron
  • through the element-wise tanh layer the relevance is backward-propagated as the identity

With these rules at hand, you can write your own custom LRP backward pass for any recurrent neural network.
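
To make this concrete, here is a minimal numpy sketch of the two non-trivial rules (the eps-rule for linear layers and the signal-take-all rule for products). The function names and the simplified handling of the bias term are my own choices for illustration, not the exact code of this repo:

```python
import numpy as np

def lrp_linear_eps(x, w, b, z, R_out, eps=0.001, bias_factor=0.0):
    """LRP eps-rule for a linear layer z = x @ w + b: redistribute the
    upper-layer relevance R_out onto the inputs x in proportion to their
    forward-pass contributions x_i * w_ij."""
    sign = np.where(z >= 0, 1.0, -1.0)
    denom = z + eps * sign                    # stabilized denominator, shape (out_dim,)
    contrib = x[:, None] * w                  # contributions, shape (in_dim, out_dim)
    # with bias_factor=1.0 the bias/stabilizer share is redistributed as well,
    # which gives exact relevance conservation (useful for sanity checks)
    bias_share = bias_factor * (b + eps * sign) / x.shape[0]
    return ((contrib + bias_share) / denom * R_out[None, :]).sum(axis=1)

def lrp_product(R_product):
    """Signal-take-all rule for a product gate * signal (or a_i * h_i):
    the 'signal' neuron receives all of the relevance, the 'gate' none."""
    R_signal = R_product
    R_gate = np.zeros_like(R_product)
    return R_signal, R_gate
```

The tanh rule needs no code at all: the relevance is simply passed through unchanged.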

Let me be more precise concerning attention layers.
Typical attention layers in recurrent neural networks (such as in Bahdanau et al. 2015 or Luong et al. 2015) contain a summation of terms of the form a_i \cdot h_i, i.e. a product of two neurons, where a_i is the attention weight (also called alignment weight), a softmax-activated neuron whose value is in the range [0, 1], and h_i is a neuron from an encoder hidden state.
During the LRP backward pass, you first need to determine the relevance of the product term (a_i \cdot h_i); this can be achieved by using the eps-rule in the summation layer.
Then, how do you redistribute this quantity to the neurons a_i and h_i? Using the signal-take-all strategy, the entire relevance of the product goes to the hidden state neuron h_i (and nothing to a_i). The whole process (1. backward LRP through the sum layer, 2. backward LRP through the product layer) amounts to treating the attention weight as a standard connection weight in a simple linear layer, which intuitively makes sense: the key idea underlying the attention mechanism is that the hidden state is the "value of interest", while the attention weight is just a "reweighting factor" in the weighted summation of hidden states over time steps.
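
As a rough sketch (variable names are hypothetical, and I assume a single context vector c = \sum_i a_i h_i computed from encoder states h of shape (T, d)), the two steps combined look like this:

```python
import numpy as np

def lrp_attention(h, a, c, R_c, eps=0.001):
    """Propagate the context-vector relevance R_c back onto the encoder
    hidden states h, for a context vector c = sum_i a_i * h_i.
    Step 1: eps-rule through the summation -> relevance of each product a_i * h_i.
    Step 2: signal-take-all through each product -> everything goes to h_i."""
    sign = np.where(c >= 0, 1.0, -1.0)
    denom = c + eps * sign                        # stabilized denominator, shape (d,)
    R_products = (a[:, None] * h) / denom * R_c   # relevance of each product, shape (T, d)
    R_h = R_products                              # nothing is assigned to the attention weights a_i
    return R_h
```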

In practice, for an LSTM model with attention, this means the relevance of the hidden states h_t comes from two relevance "message" sources which add up: 1) the standard backward computation graph through time in the LSTM, and 2) the attention layer. So in this implementation you would in particular need to change lines 216 and 224 for the LSTM left encoder, to account for the additional relevance coming from the attention layer.
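
In code this is just the sum of the two incoming messages; for example (hypothetical names, both arrays of shape (T, d)):

```python
import numpy as np

def add_attention_message(R_h_from_recurrence, R_h_from_attention):
    """Total relevance arriving at each hidden state h_t: the message from
    the standard LSTM backward pass through time plus the message from the
    attention layer (e.g. the output of lrp_attention above)."""
    return np.asarray(R_h_from_recurrence) + np.asarray(R_h_from_attention)
```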

In any case, when implementing the LRP backward pass for your model, I would highly recommend sanity-checking your implementation. You can do so by performing an LRP backward pass in a specific mode where you also redistribute the contributions from the bias and the stabilizer (i.e. by setting bias_factor to 1.0 here): in this mode you should get exact numerical conservation of the relevance between the model's output prediction score and the sum of relevances of all input neurons (including the initial hidden and cell states at time 0, as was done in this notebook cell 9). This way you can be sure that your LRP implementation is correct.
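
A minimal conservation check could look like this (variable names are hypothetical; R_words, R_h0 and R_c0 are the input-layer relevances obtained with bias_factor=1.0, and output_score is the model's prediction score for the explained class):

```python
import numpy as np

def check_conservation(R_words, R_h0, R_c0, output_score, atol=1e-6):
    """With bias_factor=1.0, the total input relevance (word embeddings plus
    the initial hidden and cell states at time 0) must equal the output score."""
    R_total = np.sum(R_words) + np.sum(R_h0) + np.sum(R_c0)
    assert np.isclose(R_total, output_score, atol=atol), \
        "relevance is not conserved -- check the LRP implementation"
```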

Hope that helps! Good luck with your project!
