Add flash attention to Transformers #342
Conversation
Signed-off-by: Walter Hugo Lopez Pinaya <ianonimato@hotmail.com>
Hi Walter,
I've noticed some discrepancies between the two implementations. It would also be good to add tests for use_flash_attention, and maybe even a test that compares the outputs of the calculation with and without flash attention, e.g. something along the lines of the sketch below.
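A sketch of such a test, assuming a CUDA device with xformers available and the SABlock signature used later in this thread (the test name and tolerances are illustrative, not part of the PR):

import torch

from generative.networks.blocks import SABlock


def test_flash_attention_matches_manual_attention():
    # Run the same input through the manual and flash-attention code paths.
    device = torch.device("cuda")
    block = SABlock(hidden_size=4, num_heads=2, dropout_rate=0.0, use_flash_attention=False).to(device)
    block.eval()
    x = torch.randn(1, 3, 4, device=device)
    with torch.no_grad():
        expected = block(x)
        block.use_flash_attention = True
        actual = block(x)
    torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4)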
import torch.nn as nn
from torch.nn import functional as F

if importlib.util.find_spec("xformers") is not None:
Currently, if the code does not find xformers but use_flash_attention is set to True, the code errors out. I think we need to set self.use_flash_attention = False in the init if has_xformers is False, and ideally raise a warning too (see the sketch below).
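A minimal sketch of that suggestion, reusing the has_xformers flag from the diff context above; the class name, constructor signature, and warning wording are illustrative assumptions, not the PR's actual code:

import importlib.util
import warnings

import torch.nn as nn

has_xformers = importlib.util.find_spec("xformers") is not None


class SABlockSketch(nn.Module):
    # Illustration only: degrade gracefully when xformers is unavailable.
    def __init__(self, use_flash_attention: bool = False) -> None:
        super().__init__()
        if use_flash_attention and not has_xformers:
            warnings.warn(
                "use_flash_attention=True but xformers is not installed; "
                "falling back to the manual attention implementation."
            )
            use_flash_attention = False
        self.use_flash_attention = use_flash_attention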
Added an error message for the case where the user wants to use flash attention but xformers is not installed.
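For concreteness, the guard that resolution describes could look like the variant below of the __init__ sketch above (the exception type and message wording are assumptions, not the exact code merged in the PR):

if use_flash_attention and not has_xformers:
    raise ValueError("use_flash_attention is True, but xformers is not installed.")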
if self.use_flash_attention:
    query = query.contiguous()
    key = key.contiguous()
    value = value.contiguous()
    y = xops.memory_efficient_attention(
        query, key, value, attn_bias=xops.LowerTriangularMask() if self.causal else None
    )

else:
    # manual implementation of attention
    attention_scores = (query @ key.transpose(-2, -1)) * self.scale

    if self.causal:
        attention_scores = attention_scores.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf"))
These two methods give different values. scale isn't currently passed to memory_efficient_attention, nor is the dropout probability, but even accounting for that, it looks like that isn't the root of the difference. The Python implementation of memory_efficient_attention looks a little different from the manual implementation here.
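For reference, a sketch of how both could be forwarded to xformers (the wrapper name is hypothetical; p is xformers' dropout argument, while an explicit scale is only accepted by more recent xformers releases):

import xformers.ops as xops


def xformers_attention(query, key, value, causal=False, dropout_p=0.0, scale=None):
    # Inputs are expected as (batch, seq_len, num_heads, head_dim) tensors.
    query, key, value = (t.contiguous() for t in (query, key, value))
    return xops.memory_efficient_attention(
        query,
        key,
        value,
        attn_bias=xops.LowerTriangularMask() if causal else None,
        p=dropout_p,  # dropout probability, matching the manual path's dropout
        scale=scale,  # assumption: requires an xformers version that exposes `scale`
    )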
Thanks for pointing this out, I fixed the problem. The two paths now give matching results:
import torch

from generative.networks.blocks import SABlock

device = torch.device("cuda")
sab = SABlock(hidden_size=4, num_heads=2, dropout_rate=0.0, use_flash_attention=False)
sab = sab.to(device)
sab.eval()
with torch.no_grad():
    x = torch.randn(1, 3, 4).to(device)
    result = sab(x)
    sab.use_flash_attention = True
    result_flash = sab(x)
torch.isclose(result, result_flash)
returning
tensor([[[True, True, True, True],
[True, True, True, True],
[True, True, True, True]]], device='cuda:0')
Signed-off-by: Walter Hugo Lopez Pinaya <ianonimato@hotmail.com>
@marksgraham I was thinking, since PyTorch 2.0 now has native flash attention, do you think it is worth abandoning the xformers implementation?
Hi @Warvito, it would be great to use the native flash attention, but should we keep the current xformers implementation for users that are on PyTorch < 2.0?
Yes, I guess it would be okay.
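A sketch of how the two backends could coexist, assuming torch.nn.functional.scaled_dot_product_attention is used on PyTorch >= 2.0 and xformers remains the fallback for older versions (the helper name and version check are assumptions, not the repository's final code):

import torch.nn.functional as F

# PyTorch >= 2.0 ships a native fused attention kernel.
_HAS_NATIVE_FLASH = hasattr(F, "scaled_dot_product_attention")


def flash_attention(query, key, value, causal=False, dropout_p=0.0):
    # Expects (batch, num_heads, seq_len, head_dim) tensors.
    if _HAS_NATIVE_FLASH:
        return F.scaled_dot_product_attention(
            query, key, value, dropout_p=dropout_p, is_causal=causal
        )
    # Fall back to xformers for PyTorch < 2.0; it expects (batch, seq, heads, dim).
    import xformers.ops as xops

    q, k, v = (t.transpose(1, 2).contiguous() for t in (query, key, value))
    out = xops.memory_efficient_attention(
        q, k, v, attn_bias=xops.LowerTriangularMask() if causal else None, p=dropout_p
    )
    return out.transpose(1, 2)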
Implements #339