This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Conversation


@Warvito Warvito commented Mar 23, 2023

Implements #339

Signed-off-by: Walter Hugo Lopez Pinaya <ianonimato@hotmail.com>
@Warvito Warvito linked an issue Mar 23, 2023 that may be closed by this pull request
Signed-off-by: Walter Hugo Lopez Pinaya <ianonimato@hotmail.com>
@Warvito Warvito marked this pull request as ready for review March 24, 2023 07:27
@marksgraham marksgraham self-requested a review March 24, 2023 14:54
@marksgraham marksgraham (Collaborator) left a comment

Hi Walter,

I've noticed some discrepancies between the two implementations. It would also be good to add tests for use_flash_attention, and maybe even a test that compares the outputs of the calculation with and without flash attention.
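
A hedged sketch of what such a test could look like (the SABlock arguments follow the script later in this thread; the skip condition and tolerances are illustrative assumptions, not the test that was eventually added):

import unittest

import torch

from generative.networks.blocks import SABlock


class TestFlashAttention(unittest.TestCase):
    @unittest.skipUnless(torch.cuda.is_available(), "flash attention requires a CUDA device")
    def test_flash_matches_manual(self):
        device = torch.device("cuda")
        block = SABlock(hidden_size=4, num_heads=2, dropout_rate=0.0, use_flash_attention=False).to(device)
        block.eval()
        x = torch.randn(1, 3, 4, device=device)
        with torch.no_grad():
            expected = block(x)
            block.use_flash_attention = True  # toggle to the flash attention path
            actual = block(x)
        torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4)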

import importlib.util

import torch.nn as nn
from torch.nn import functional as F

if importlib.util.find_spec("xformers") is not None:
Collaborator

Currently, if the code does not find xformers but use_flash_attention is set to True, the code errors out. I think we need to set self.use_flash_attention = False in the init if has_xformers is False, and ideally raise a warning, too.

Collaborator Author

Added an error message in case the user wants to use flash attention but xformers is not installed.
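
A minimal sketch of the kind of guard being described, assuming the has_xformers flag from the snippet above; the constructor arguments mirror SABlock as used later in this thread, and the exact error message is an assumption rather than the merged code:

import importlib.util

import torch.nn as nn

has_xformers = importlib.util.find_spec("xformers") is not None


class SABlock(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0,
                 use_flash_attention: bool = False) -> None:
        super().__init__()
        # Fail fast when flash attention is requested but xformers is missing,
        # instead of erroring out later inside the forward pass.
        if use_flash_attention and not has_xformers:
            raise ValueError("use_flash_attention is True but the xformers package is not installed.")
        self.use_flash_attention = use_flash_attention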

Comment on lines 111 to 124
if self.use_flash_attention:
    query = query.contiguous()
    key = key.contiguous()
    value = value.contiguous()
    y = xops.memory_efficient_attention(
        query, key, value, attn_bias=xops.LowerTriangularMask() if self.causal else None
    )
else:
    # manual implementation of attention
    attention_scores = (query @ key.transpose(-2, -1)) * self.scale

    if self.causal:
        attention_scores = attention_scores.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf"))
Collaborator

These two methods give different values. scale isn't currently passed to memory_efficient_attention, nor is the dropout probability, but even accounting for that, it looks like that isn't the root of the difference. The Python implementation of memory_efficient_attention looks a little different from the manual implementation here:

https://github.com/facebookresearch/xformers/blob/658ebab39545f180a6075385b3897921623d6c3b/xformers/ops/fmha/__init__.py#L142-L149
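
For reference, memory_efficient_attention accepts the softmax scale and dropout probability as keyword arguments (p and scale in the xformers API); a hedged sketch of passing them explicitly, not necessarily how the fix was implemented:

import torch
import xformers.ops as xops


def flash_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                    scale: float, dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
    # xformers expects inputs in [batch, seq_len, num_heads, head_dim] layout.
    # Passing the same scale and dropout probability as the manual path keeps
    # the two code paths computing the same quantity.
    return xops.memory_efficient_attention(
        query.contiguous(),
        key.contiguous(),
        value.contiguous(),
        attn_bias=xops.LowerTriangularMask() if causal else None,
        p=dropout_p,
        scale=scale,
    )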

Collaborator Author

Thanks for pointing this out, I fixed the problem:

import torch

from generative.networks.blocks import SABlock

device = torch.device("cuda")
sab = SABlock(hidden_size=4, num_heads=2, dropout_rate=0.0, use_flash_attention=False)
sab = sab.to(device)
sab.eval()

with torch.no_grad():
    x = torch.randn(1, 3, 4).to(device)
    result = sab(x)

    sab.use_flash_attention = True
    result_flash = sab(x)

torch.isclose(result, result_flash)

returning

tensor([[[True, True, True, True],
         [True, True, True, True],
         [True, True, True, True]]], device='cuda:0')

Warvito added 2 commits March 25, 2023 12:15
Signed-off-by: Walter Hugo Lopez Pinaya <ianonimato@hotmail.com>
Signed-off-by: Walter Hugo Lopez Pinaya <ianonimato@hotmail.com>
@Warvito Warvito requested a review from marksgraham March 25, 2023 12:20
@Warvito Warvito commented Mar 25, 2023

@marksgraham I was thinking, since PyTorch 2.0 now has native flash attention, do you think it is worth abandoning the xformers dependency and adopting the native option instead? Since PyTorch 2.0 is backwards compatible, it would not be an issue to update to it in place of installing xformers (which requires PyTorch > 1.13).
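
For reference, the native option is torch.nn.functional.scaled_dot_product_attention, which dispatches to a flash-attention kernel when the device and dtypes allow it; a minimal sketch, with shapes chosen only for illustration:

import torch
import torch.nn.functional as F

# [batch, num_heads, seq_len, head_dim]; half precision on GPU enables the flash kernel.
query = torch.randn(1, 2, 16, 8, device="cuda", dtype=torch.float16)
key = torch.randn(1, 2, 16, 8, device="cuda", dtype=torch.float16)
value = torch.randn(1, 2, 16, 8, device="cuda", dtype=torch.float16)

# dropout_p and is_causal mirror the options handled manually in this PR.
y = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=True)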

@marksgraham

Hi @Warvito

It would be great to use the native flash attention, but should we keep the current xformers implementation for users that are on PyTorch < 2.0?
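
One possible way to keep both paths, sketched as a hypothetical backend selection rather than what was merged:

import importlib.util

import torch

has_xformers = importlib.util.find_spec("xformers") is not None
has_native_flash = hasattr(torch.nn.functional, "scaled_dot_product_attention")  # PyTorch >= 2.0

# Prefer the native kernel, fall back to xformers, then to the manual implementation.
if has_native_flash:
    attention_backend = "native"    # F.scaled_dot_product_attention
elif has_xformers:
    attention_backend = "xformers"  # xops.memory_efficient_attention
else:
    attention_backend = "manual"    # softmax(Q @ K.transpose(-2, -1) * scale) @ V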

@Warvito Warvito commented Mar 27, 2023

Yes, I guess it would be okay

@Warvito Warvito merged commit b5ef5ff into main Mar 27, 2023
@Warvito Warvito deleted the 339-add-flash-attention-to-transformers branch March 31, 2023 18:25


Development

Successfully merging this pull request may close these issues.

Add Flash attention to transformers
