Implementation of multiple attention mechanisms #138

Conversation
```python
"""
taken from https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
"""
L, S = query.size(-2), key.size(-2)
```
Why do we specify L and S separately? Shouldn't the context size for query and key, i.e. `query.size(-2)` and `key.size(-2)`, always be the same?
Also, what do L and S stand for?
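For reference, in the PyTorch documentation L is the target (query) sequence length and S is the source (key/value) sequence length; the two differ, for example, in cross-attention. A minimal sketch with illustrative shapes (not the PR's actual code):

```python
import torch
import torch.nn.functional as F

# L (query length) and S (key/value length) need not be equal, e.g. in
# cross-attention, which is why the reference code tracks them separately.
# All shapes below are made up for illustration.
query = torch.rand(1, 2, 5, 8)  # (batch, heads, L=5, head_dim)
key = torch.rand(1, 2, 7, 8)    # (batch, heads, S=7, head_dim)
value = torch.rand(1, 2, 7, 8)  # (batch, heads, S=7, head_dim)

out = F.scaled_dot_product_attention(query, key, value)
print(out.shape)  # torch.Size([1, 2, 5, 8]) -- output keeps the query length L
```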
```python
head_dim = n_embd // n_head_q
AttentionConfig(qkv_transforms=[])
```

```python
q = torch.rand(batch_size, n_head_q, block_size - 1, head_dim, dtype=torch.bfloat16).cuda()
```
As an idea: instead of `torch.rand` we could also do `torch.arange(0, batch_size * n_head_q * (block_size - 1) * head_dim).reshape(batch_size, n_head_q, block_size - 1, head_dim)`. In this case, we can check for exact equality instead of approximate equality.
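A sketch of the suggestion above, with hypothetical values for the size parameters (the PR's actual test configuration may differ): a deterministic input built with `torch.arange` instead of `torch.rand`, so repeated runs receive bit-identical inputs.

```python
import torch

# Illustrative sizes only; the real test uses its own batch_size, n_head_q, etc.
batch_size, n_head_q, block_size, head_dim = 2, 4, 8, 16

numel = batch_size * n_head_q * (block_size - 1) * head_dim
# Deterministic, reproducible input tensor: 0, 1, 2, ... reshaped to the
# (batch, heads, sequence, head_dim) layout used by the attention tests.
q = torch.arange(0, numel, dtype=torch.float32).reshape(
    batch_size, n_head_q, block_size - 1, head_dim
)
print(q.shape)  # torch.Size([2, 4, 7, 16])
```

Note that even with identical inputs, two different attention implementations may still diverge in low-precision dtypes such as bfloat16, so exact equality is most realistic when comparing runs of the *same* implementation.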
```python
    output_tensor[attention_impl_2],
    atol=2.5e-3,  # default for bfloat16: 1e-5
    rtol=0.016,   # default for bfloat16: 0.016
)
```
We could add another test that compares the output of the manual attention implementation against a precomputed (pen-and-paper) output for a very short input sequence. In that case, we can be entirely sure that the implementation is correct.
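To illustrate the idea, here is a hedged sketch of such a hand-checkable case (not the PR's code): with an all-zero query, every attention score is 0, the softmax yields uniform weights, and the output is simply the mean of the value rows, which is trivial to verify on paper.

```python
import math
import torch

# (batch, head, seq, head_dim) = (1, 1, 2, 2); all values chosen by hand.
q = torch.zeros(1, 1, 2, 2)                       # zero query -> zero scores
k = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]])
v = torch.tensor([[[[2.0, 4.0], [6.0, 8.0]]]])

scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))  # all zeros
weights = torch.softmax(scores, dim=-1)                   # uniform: 0.5 each
out = weights @ v                       # each row = mean of value rows
print(out)  # tensor([[[[4., 6.], [4., 6.]]]])
```

Because the expected output ([4, 6] in every position) is computed by hand, any mismatch points unambiguously at a bug in the implementation under test.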
This PR implements `manual attention` and `pytorch flash attention`, in addition to the previously implemented `dao flash attention`. Group Query Attention is supported.
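For readers unfamiliar with Group Query Attention, a minimal sketch of the mechanism (names and shapes are illustrative assumptions, not this PR's implementation): fewer key/value heads than query heads, with each key/value head shared by a group of query heads.

```python
import torch
import torch.nn.functional as F

# Illustrative sizes: 8 query heads share 2 key/value heads (groups of 4).
batch, n_head_q, n_head_kv, seq, head_dim = 2, 8, 2, 16, 32
q = torch.rand(batch, n_head_q, seq, head_dim)
k = torch.rand(batch, n_head_kv, seq, head_dim)
v = torch.rand(batch, n_head_kv, seq, head_dim)

# One simple way to realize GQA: expand the key/value heads so that each
# group of query heads attends to its shared key/value head.
repeat = n_head_q // n_head_kv
k = k.repeat_interleave(repeat, dim=1)
v = v.repeat_interleave(repeat, dim=1)

out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([2, 8, 16, 32])
```

GQA reduces the key/value cache size roughly by the factor `n_head_q / n_head_kv` while keeping the full number of query heads.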