
Implementation of multiple attention mechanisms#138

Merged
le1nux merged 10 commits intodev_experimentsfrom
feat/multiple_attention_implementations
Jun 10, 2024

Conversation

@flxst
Member

@flxst flxst commented May 27, 2024

This PR implements manual attention and PyTorch flash attention, in addition to the previously implemented Dao flash attention. Grouped Query Attention is supported.

@flxst flxst requested review from le1nux and mali-git May 27, 2024 15:12
@flxst flxst changed the title Implementation of multiple attention implementations Implementation of multiple attention mechanisms May 27, 2024
@le1nux le1nux added the enhancement New feature or request label May 29, 2024
@flxst flxst requested a review from fromm-m June 3, 2024 13:01
"""
taken from https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
"""
L, S = query.size(-2), key.size(-2)
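For context, the quoted lines are the opening of the reference implementation given in the PyTorch docs for `scaled_dot_product_attention`. A minimal sketch of the full manual computation (no dropout, no GQA; the function name is mine) might look like:

```python
import math
import torch

def manual_sdpa(query, key, value, is_causal=False):
    # L: query sequence length, S: key/value sequence length
    L, S = query.size(-2), key.size(-2)
    scale = 1.0 / math.sqrt(query.size(-1))
    attn_bias = torch.zeros(L, S, dtype=query.dtype)
    if is_causal:
        # Mask out positions above the diagonal so each token only
        # attends to itself and earlier tokens.
        mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
        attn_bias.masked_fill_(mask.logical_not(), float("-inf"))
    attn_weight = query @ key.transpose(-2, -1) * scale
    attn_weight = attn_weight + attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    return attn_weight @ value
```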
Why do we specify L and S separately? Shouldn't the context size for query and key, i.e., query.size(-2) and key.size(-2), always be the same?
Also, what do L and S stand for?
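In the PyTorch docs' notation, L is the target (query) sequence length and S is the source (key/value) sequence length. They can differ, e.g. in cross-attention or during incremental decoding with a KV cache. A small sketch (shapes are illustrative, not from the PR):

```python
import torch
import torch.nn.functional as F

# During incremental decoding with a KV cache, the query holds only the
# newest token (L = 1) while key/value cover the whole prefix (S = 5),
# so query.size(-2) and key.size(-2) legitimately differ.
batch, heads, head_dim = 1, 2, 8
q = torch.randn(batch, heads, 1, head_dim)  # L = 1
k = torch.randn(batch, heads, 5, head_dim)  # S = 5
v = torch.randn(batch, heads, 5, head_dim)

out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 2, 1, 8]) -- output keeps the query length
```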

head_dim = n_embd // n_head_q
AttentionConfig(qkv_transforms=[])

q = torch.rand(batch_size, n_head_q, block_size - 1, head_dim, dtype=torch.bfloat16).cuda()

As an idea: instead of torch.rand we could also use torch.arange(0, batch_size * n_head_q * (block_size - 1) * head_dim).reshape(batch_size, n_head_q, block_size - 1, head_dim). In this case, we can check for exact equality instead of approximate equality.
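One caveat with this idea: large integer values from arange are not exactly representable in bfloat16, so the deterministic input would need to be scaled down first. A sketch of what that could look like (shapes are hypothetical):

```python
import torch

batch_size, n_head_q, block_size, head_dim = 2, 4, 8, 16
n = batch_size * n_head_q * (block_size - 1) * head_dim

# Deterministic input as suggested; dividing by n keeps all values in [0, 1),
# small enough to round-trip through bfloat16 without collisions being an
# issue for the test's purpose.
q = (torch.arange(0, n, dtype=torch.float32) / n).reshape(
    batch_size, n_head_q, block_size - 1, head_dim
).to(torch.bfloat16)

# With identical deterministic inputs, two runs of the same implementation
# can be compared with torch.equal(...) instead of torch.testing.assert_close.
```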

output_tensor[attention_impl_2],
atol=2.5e-3, # default for bfloat16: 1e-5
rtol=0.016, # default for bfloat16: 0.016
)

We could add another test in which we check the output of the manual attention implementation against a precomputed output (pen and paper) for a very short input sequence. In that case, we can be entirely sure that the implementation is correct.
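A sketch of what such a pen-and-paper test could look like, with hypothetical numbers (here checked against PyTorch's built-in SDPA rather than the PR's manual implementation): zero queries give all-zero attention scores, so the softmax is uniform and the output is simply the mean of the value vectors.

```python
import torch
import torch.nn.functional as F

# Tiny hand-computable case: seq_len = 2, head_dim = 1.
# scores = q @ k^T = 0 everywhere -> softmax = [0.5, 0.5] per row
# -> output = 0.5 * 1.0 + 0.5 * 3.0 = 2.0 for every position.
q = torch.zeros(1, 1, 2, 1)            # (batch, heads, seq, head_dim)
k = torch.tensor([[[[1.0], [2.0]]]])
v = torch.tensor([[[[1.0], [3.0]]]])

expected = torch.full((1, 1, 2, 1), 2.0)
out = F.scaled_dot_product_attention(q, k, v)
assert torch.allclose(out, expected, atol=1e-6)
```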


@le1nux le1nux left a comment


Left some minor comments, but I think we can merge it already! Really nice that we can now switch between the different attention implementations and also have them fully tested. Nice work! :)

@le1nux le1nux merged commit cb65b87 into dev_experiments Jun 10, 2024
@le1nux le1nux deleted the feat/multiple_attention_implementations branch June 10, 2024 21:48
@le1nux le1nux mentioned this pull request Jun 11, 2024

Labels

enhancement New feature or request


3 participants