
Add --xformers-flash-attention option to improve the reproducibility #6988

Merged

Conversation

takuma104
Contributor

What does this PR do?

[Figure: example generations with and without Flash Attention]

Add the --xformers-flash-attention argument to enable the use of Flash Attention in xFormers. Using Flash Attention improves the reproducibility of SD image generation due to its deterministic behavior.

xFormers' reproducibility problem has been discussed in several issues over the last year. I think the most exhaustive summary is this comment by @0xdevalias in #4011.
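
For context, here is a minimal sketch of the mechanism, not the exact diff in this PR: when the flag is set, the Flash Attention op is requested explicitly; otherwise op=None lets xFormers pick its default backend. The helper name memory_efficient_attention_maybe_flash is illustrative only.

import xformers.ops
from modules import shared  # webui module exposing parsed command-line options

def memory_efficient_attention_maybe_flash(q, k, v):
    # Illustrative helper: request the Flash Attention op only when the new
    # flag is set; op=None keeps xFormers' default backend (usually Cutlass).
    # (The actual PR also checks op support and falls back; see the
    # discussion and the snippet by Shondoit below.)
    op = None
    if getattr(shared.cmd_opts, 'xformers_flash_attention', False):
        op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
    return xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=op)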

Limitation

Unfortunately, Flash Attention won't accept the attention shapes used by SD1.x, but it will accept those of SD2.x and its variants.
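
To illustrate the limitation, here is a small probe sketch (not part of the PR) that asks xFormers whether the Flash Attention forward op accepts a given input. The tensor shapes are illustrative only; a head dimension above 128, as in some SD1.x attention layers, is rejected. It assumes a CUDA GPU and half-precision tensors.

import torch
import xformers.ops

def flash_attention_supported(q, k, v, attn_bias=None):
    # Ask the Flash Attention forward op whether it can handle these inputs.
    fw, _bw = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
    return fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=attn_bias))

# Illustrative shapes (batch, tokens, heads, head_dim).
q = k = v = torch.randn(1, 4096, 8, 160, dtype=torch.half, device='cuda')
print(flash_attention_supported(q, k, v))   # expected: False (head_dim > 128)
q = k = v = torch.randn(1, 4096, 8, 64, dtype=torch.half, device='cuda')
print(flash_attention_supported(q, k, v))   # expected: True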

Test

I have tested this PR on Ubuntu using the following script. I don't have a Windows environment personally, so I would appreciate it if someone could test it on Windows.
https://gist.github.com/takuma104/58fbd99a02006c67dbb9ff968c7417f2

Test Environment

  • OS: Ubuntu 22.04
  • Browser: Chrome 109.0.5414.87 (on MacOS)
  • Graphics card: NVIDIA RTX 3060 12GB
  • xFormers Info:
$ python -m xformers.info
A matching Triton is not available, some optimizations will not be enabled.
Error caught was: No module named 'triton'
xFormers 0.0.15+3df785c.d20230111
memory_efficient_attention.cutlassF:               available
memory_efficient_attention.cutlassB:               available
memory_efficient_attention.flshattF:               available
memory_efficient_attention.flshattB:               available
memory_efficient_attention.smallkF:                available
memory_efficient_attention.smallkB:                available
memory_efficient_attention.tritonflashattF:        unavailable
memory_efficient_attention.tritonflashattB:        unavailable
swiglu.fused.p.cpp:                                available
is_triton_available:                               False
is_functorch_available:                            False
pytorch.version:                                   1.13.1
pytorch.cuda:                                      available
gpu.compute_capability:                            8.6
gpu.name:                                          NVIDIA GeForce RTX 3060

Discussion

Any suggestions are welcome.

  • The code now falls back to the default backend (usually Cutlass) if Flash Attention is not available. If a warning were emitted on every fallback, the SD1.x series would generate a huge amount of log output. I think it would be better to give some sort of warning, but does anyone have a suggestion? (One possibility is a warn-once pattern, sketched below.)
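
For example, a warn-once pattern along these lines (a sketch only, not code in this PR) would surface the fallback without flooding the log on every attention call:

_flash_attention_fallback_warned = False

def warn_flash_attention_fallback(reason):
    # Print the fallback warning only once per process instead of once per
    # attention call, which for SD1.x would otherwise flood the console.
    global _flash_attention_fallback_warned
    if not _flash_attention_fallback_warned:
        print(f"xFormers Flash Attention unavailable, using default backend: {reason}")
        _flash_attention_fallback_warned = True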

takuma104 changed the title from "add --xformers-flash-attention option & impl" to "Add --xformers-flash-attention option to improve the reproducibility" on Jan 21, 2023
@Shondoit
Contributor

I've been working on xFormers reproducibility as well, so I'm looking forward to this being implemented.

That being said, have you tried this with xformers 0.0.14.dev0, the version currently used by the WebUI?
Because I saw an incredible number of changes between 0.0.14 and the later versions.

@AUTOMATIC1111
Owner

If this does not work on SD1 the usefulness of this is going to be extremely limited.

@AUTOMATIC1111
Owner

I'm getting reproducible results on torch: 1.13.1+cu117 xformers: 0.0.16rc396, using SD1, no warnings or errors of any kind.

Have we found a reason to upgrade?

@Shondoit
Contributor

In my short test of SD1, it seems to call xformers_attention_forward 32 times per step.
The first 8 calls can use Flash Attention.
The next 12 cannot, the reason being ['max(query.shape[-1] != value.shape[-1]) > 128'].
And the last 12 can.

I'm not well versed enough to know what those 32 calls are, but it seems that a lot of SD1 becomes at least a bit more reproducible.

For SD2, every call can use Flash Attention, and it benefits greatly from the improved reproducibility.

My recommendation would be to extract the code into its own function, for example:

import xformers.ops
from modules import shared  # webui module exposing cmd_opts

def get_supported_attention_op(q, k, v, attn_bias=None):
    # Return the Flash Attention op if the flag is set and the op supports
    # these inputs; otherwise return None so xFormers picks its default backend.
    op = None
    if shared.cmd_opts.xformers_flash_attention:
        op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
        fw, bw = op
        if not fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=attn_bias)):
            op = None
    return op

# At the attention call site:
op = get_supported_attention_op(q, k, v, attn_bias=None)
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=op)

@AUTOMATIC1111
Owner

And on torch: 1.12.1+cu113 with xformers: 0.0.14.dev, this just breaks image generation with an error.

AUTOMATIC1111 merged commit 3fa4820 into AUTOMATIC1111:master on Jan 23, 2023
@Shondoit
Contributor

@AUTOMATIC1111 I'm not sure why you merged the PR if it breaks when using the current dependencies?

Shouldn't launch.py be amended to include these new requirements? Currently it's still using the old ones:

torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113")
...
xformers_windows_package = os.environ.get('XFORMERS_WINDOWS_PACKAGE', 'https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl')
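
A hypothetical amendment along these lines would bump the default torch pin to the combination reported reproducible above (torch 1.13.1+cu117); the matching xformers package/wheel would need updating as well. The exact pins below are an assumption for illustration, not the merged change.

import os

# Hypothetical sketch only; version pins are assumptions based on the
# versions mentioned in this thread, not the actual commit.
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117")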

@AUTOMATIC1111
Owner

Because I included a commit of my own to make it not break.

@Shondoit
Contributor

Ah, okay.

So you're keeping the requirements at CUDA 11.3 for now, but informing people that Flash Attention is not used.
