Support returning attention weights in naive attention modules #589
Adds a `return_attn_weights` option to the `forward` method of the `SelfAttention` and `CrossAttention` modules, as well as to `MHA` and indirectly to `Block`.

Motivation
Since FlashAttention does not explicitly materialize the full attention matrix, it does not provide access to the attention weights. However, these weights are useful, or even required, for many applications. As a workaround, FlashAttention can be disabled on a (possibly already pretrained) model, and the option proposed here can then be used to access the attention weights.
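For illustration, a minimal sketch of that workaround, assuming the proposed `return_attn_weights` option lands on `MHA.forward` and that the weights are returned alongside the output (the exact return signature may differ in the final implementation):

```python
import torch
from flash_attn.modules.mha import MHA

# Disable FlashAttention so the naive (non-fused) attention path is used,
# which is the only path that can materialize the attention matrix.
mha = MHA(
    embed_dim=512,
    num_heads=8,
    use_flash_attn=False,
)

x = torch.randn(2, 128, 512)  # (batch, seqlen, embed_dim)

# return_attn_weights is the option proposed in this PR; whether the weights
# come back as a second tuple element is an assumption for this sketch.
out, attn_weights = mha(x, return_attn_weights=True)

# Expected shape: (batch, num_heads, seqlen, seqlen)
print(attn_weights.shape)
```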
Discussion
Since the attention modules are usually deeply nested within Transformer-like architectures, the `return_attn_weights` argument and the corresponding return values have to be propagated through the whole layer chain. This leads to strong entanglement between the modules, as can be seen in the changes.

An alternative could be to implement the option only for the low-level attention modules, i.e. `CrossAttention` and `SelfAttention`, similar to what PyTorch does in its MHA implementation, but not expose it in any upper layer. The attention weights would then need to be extracted via a forward hook (see the sketch below). While this would complicate the process of extracting attention maps, it would reduce inter-layer dependencies and thus improve maintainability. Since extracting attention maps does not appear to be a highly requested feature, this could be seen as the preferred solution.
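As a rough sketch of that hook-based alternative: the low-level module would be told locally to produce weights, and callers would collect them with a forward hook instead of threading an argument through every layer. The module-level flag (`return_attn_weights` as an attribute) and the `(output, attn_weights)` tuple below are assumptions about what such an implementation might look like, not the current API; only `mha.inner_attn` and `register_forward_hook` exist today.

```python
import torch
from flash_attn.modules.mha import MHA

captured = {}

def save_attn_weights(module, inputs, output):
    # Assumes a patched SelfAttention that returns (context, attn_weights)
    # when its module-level flag is enabled; otherwise output is a tensor.
    if isinstance(output, tuple) and len(output) == 2:
        captured[module] = output[1].detach()

mha = MHA(embed_dim=512, num_heads=8, use_flash_attn=False)

# Enable weight output on the inner attention module and attach the hook.
mha.inner_attn.return_attn_weights = True  # hypothetical flag
handle = mha.inner_attn.register_forward_hook(save_attn_weights)

x = torch.randn(2, 128, 512)
_ = mha(x)

handle.remove()
for module, weights in captured.items():
    print(type(module).__name__, weights.shape)
```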
Todos