Support returning attention weights in naive attention modules #589

Open · wants to merge 2 commits into main
Conversation

@kklemon commented Oct 4, 2023

Adds a return_attn_weights option to the forward method of the SelfAttention and CrossAttention modules, as well as to MHA and, indirectly, to Block.
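For illustration, a minimal sketch of what a naive attention path with such a flag could look like; the function name and exact signature are assumptions for demonstration, not the code from this PR:

```python
import torch

def naive_attention(q, k, v, softmax_scale=None, return_attn_weights=False):
    # q, k, v: (batch, seqlen, nheads, head_dim), as in the naive (non-Flash) path
    softmax_scale = softmax_scale or q.shape[-1] ** -0.5
    scores = torch.einsum("bthd,bshd->bhts", q, k) * softmax_scale
    attn_weights = torch.softmax(scores, dim=-1)
    out = torch.einsum("bhts,bshd->bthd", attn_weights, v)
    # Only the optional flag changes the return value; the default path is untouched.
    if return_attn_weights:
        return out, attn_weights
    return out
```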

Motivation

Since FlashAttention does not materialize the full attention matrix, it does not allow accessing or extracting attention weights. However, this can be useful or even required for many applications. As a workaround, FlashAttention can be disabled on a (possibly already pretrained) model, and the proposed option can then be used to access the attention weights.
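A hedged sketch of that workflow; the constructor and keyword names below are assumptions based on this description, not verified against the changed code:

```python
import torch
from flash_attn.modules.mha import MHA

# Build (or reconfigure) the module with FlashAttention disabled so the naive
# SelfAttention path is taken, then request the weights via the proposed flag.
mha = MHA(embed_dim=512, num_heads=8, use_flash_attn=False)
x = torch.randn(2, 128, 512)
out, attn_weights = mha(x, return_attn_weights=True)  # attn_weights: (batch, nheads, seqlen, seqlen)
```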

Discussion

Since the attention modules are usually deeply nested within Transformer-like architectures, the return_attn_weights argument and return values have to be propagated through the whole layer chain. This leads to a strong entanglement between the modules, as seen in the changes.

An alternative could be to implement the option only for the low-level attention modules, i.e. Cross- and SelfAttention, similar to what PyTorch does in its MultiheadAttention implementation, but not expose it in any upper layer. The attention weights would then need to be extracted via a forward hook (see the sketch below). While this would complicate the process of extracting attention maps, it would reduce inter-layer dependencies and thus improve maintainability.
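As a rough illustration of the hook-based alternative, the stand-in module below mimics a low-level attention module that returns its weights, which a forward hook then captures; all names are hypothetical:

```python
import torch
import torch.nn as nn

class NaiveSelfAttention(nn.Module):
    """Stand-in for a low-level SelfAttention module that also returns its weights."""
    def forward(self, qkv):  # qkv: (batch, seqlen, 3, nheads, head_dim)
        q, k, v = qkv.unbind(dim=2)
        scores = torch.einsum("bthd,bshd->bhts", q, k) * q.shape[-1] ** -0.5
        attn = scores.softmax(dim=-1)
        return torch.einsum("bhts,bshd->bthd", attn, v), attn

captured = {}

def save_attn(module, inputs, output):
    captured["attn"] = output[1].detach()  # grab the attention map from the output tuple

attn_module = NaiveSelfAttention()
handle = attn_module.register_forward_hook(save_attn)
out, _ = attn_module(torch.randn(2, 16, 3, 8, 64))
handle.remove()
print(captured["attn"].shape)  # torch.Size([2, 8, 16, 16])
```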

Since extracting attention maps does not appear to be a highly requested feature, this could be seen as the preferred solution.

Todos

  • Add tests
